prasb committed
Commit 89b4844 · verified
1 Parent(s): ab75471

Add files using upload-large-folder tool
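
A commit like this is typically produced with the upload-large-folder tool from the huggingface_hub library, which uploads a large local directory as a series of resumable commits. Below is a minimal sketch assuming a recent huggingface_hub version; the repo id, folder path, and repo type are placeholders, not values taken from this commit:

    # Hypothetical sketch of the upload-large-folder workflow (placeholders throughout).
    from huggingface_hub import HfApi

    api = HfApi()  # assumes prior authentication, e.g. `huggingface-cli login` or HF_TOKEN
    api.upload_large_folder(
        repo_id="your-username/your-repo",     # placeholder target repository
        folder_path="./my_container_sandbox",  # local directory to mirror into the repo
        repo_type="model",                     # or "dataset", depending on the target repo
    )

The same operation is also exposed on the command line, roughly as `huggingface-cli upload-large-folder <repo-id> <local-path> --repo-type=model`.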

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete set.
Files changed (50)
  1. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/imageio/resources/images/stent.npz +3 -0
  2. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/__init__.py +56 -0
  3. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/base.py +368 -0
  4. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/device_stats_monitor.py +104 -0
  5. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/early_stopping.py +261 -0
  6. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/finetuning.py +417 -0
  7. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/gpu_stats_monitor.py +262 -0
  8. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/lambda_function.py +96 -0
  9. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/lr_monitor.py +354 -0
  10. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py +720 -0
  11. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_summary.py +73 -0
  12. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/prediction_writer.py +119 -0
  13. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/pruning.py +486 -0
  14. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/quantization.py +344 -0
  15. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/rich_model_summary.py +109 -0
  16. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/stochastic_weight_avg.py +280 -0
  17. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/timer.py +176 -0
  18. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/xla_stats_monitor.py +114 -0
  19. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py +264 -0
  20. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/decorators.py +60 -0
  21. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/hooks.py +828 -0
  22. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py +409 -0
  23. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/saving.py +419 -0
  24. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/distributed/__init__.py +14 -0
  25. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/distributed/dist.py +47 -0
  26. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/__pycache__/__init__.cpython-38.pyc +0 -0
  27. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/__pycache__/layer_sync.cpython-38.pyc +0 -0
  28. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__init__.py +20 -0
  29. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__pycache__/__init__.cpython-38.pyc +0 -0
  30. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__pycache__/bagua_environment.cpython-38.pyc +0 -0
  31. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/bagua_environment.py +62 -0
  32. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/cluster_environment.py +87 -0
  33. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/kubeflow_environment.py +78 -0
  34. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/lightning_environment.py +101 -0
  35. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/lsf_environment.py +190 -0
  36. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/slurm_environment.py +134 -0
  37. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/torchelastic_environment.py +88 -0
  38. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/__init__.py +17 -0
  39. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/__pycache__/xla_plugin.cpython-38.pyc +0 -0
  40. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/checkpoint_plugin.py +62 -0
  41. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/hpu_plugin.py +52 -0
  42. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/torch_plugin.py +96 -0
  43. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/xla_plugin.py +57 -0
  44. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__init__.py +27 -0
  45. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/apex_amp.cpython-38.pyc +0 -0
  46. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/deepspeed.cpython-38.pyc +0 -0
  47. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/double.cpython-38.pyc +0 -0
  48. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/fully_sharded_native_amp.cpython-38.pyc +0 -0
  49. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/mixed.cpython-38.pyc +0 -0
  50. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/native_amp.cpython-38.pyc +0 -0
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/imageio/resources/images/stent.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60a83d2296b51ee6a53153e9ba96ba9020391b0c8952895d9d60a0a629ac6bb6
3
+ size 824612
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/__init__.py ADDED
@@ -0,0 +1,56 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from pytorch_lightning.callbacks.base import Callback
15
+ from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
16
+ from pytorch_lightning.callbacks.early_stopping import EarlyStopping
17
+ from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
18
+ from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
19
+ from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
20
+ from pytorch_lightning.callbacks.lambda_function import LambdaCallback
21
+ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
22
+ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
23
+ from pytorch_lightning.callbacks.model_summary import ModelSummary
24
+ from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
25
+ from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar, TQDMProgressBar
26
+ from pytorch_lightning.callbacks.pruning import ModelPruning
27
+ from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining
28
+ from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
29
+ from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging
30
+ from pytorch_lightning.callbacks.timer import Timer
31
+ from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
32
+
33
+ __all__ = [
34
+ "BackboneFinetuning",
35
+ "BaseFinetuning",
36
+ "Callback",
37
+ "DeviceStatsMonitor",
38
+ "EarlyStopping",
39
+ "GPUStatsMonitor",
40
+ "XLAStatsMonitor",
41
+ "GradientAccumulationScheduler",
42
+ "LambdaCallback",
43
+ "LearningRateMonitor",
44
+ "ModelCheckpoint",
45
+ "ModelPruning",
46
+ "ModelSummary",
47
+ "BasePredictionWriter",
48
+ "ProgressBar",
49
+ "ProgressBarBase",
50
+ "QuantizationAwareTraining",
51
+ "RichModelSummary",
52
+ "RichProgressBar",
53
+ "StochasticWeightAveraging",
54
+ "Timer",
55
+ "TQDMProgressBar",
56
+ ]
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/base.py ADDED
@@ -0,0 +1,368 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ r"""
15
+ Base class used to build new callbacks.
16
+
17
+ """
18
+
19
+ from typing import Any, Dict, List, Optional, Type
20
+
21
+ import torch
22
+ from torch.optim import Optimizer
23
+
24
+ import pytorch_lightning as pl
25
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
26
+
27
+
28
+ class Callback:
29
+ r"""
30
+ Abstract base class used to build new callbacks.
31
+
32
+ Subclass this class and override any of the relevant hooks
33
+ """
34
+
35
+ @property
36
+ def state_key(self) -> str:
37
+ """Identifier for the state of the callback.
38
+
39
+ Used to store and retrieve a callback's state from the checkpoint dictionary by
40
+ ``checkpoint["callbacks"][state_key]``. Implementations of a callback need to provide a unique state key if 1)
41
+ the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
42
+ """
43
+ return self.__class__.__qualname__
44
+
45
+ @property
46
+ def _legacy_state_key(self) -> Type["Callback"]:
47
+ """State key for checkpoints saved prior to version 1.5.0."""
48
+ return type(self)
49
+
50
+ def _generate_state_key(self, **kwargs: Any) -> str:
51
+ """Formats a set of key-value pairs into a state key string with the callback class name prefixed. Useful
52
+ for defining a :attr:`state_key`.
53
+
54
+ Args:
55
+ **kwargs: A set of key-value pairs. Must be serializable to :class:`str`.
56
+ """
57
+ return f"{self.__class__.__qualname__}{repr(kwargs)}"
58
+
59
+ def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
60
+ r"""
61
+ .. deprecated:: v1.6
62
+ This callback hook was deprecated in v1.6 and will be removed in v1.8. Use `setup()` instead.
63
+
64
+ Called before configure sharded model.
65
+ """
66
+
67
+ def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
68
+ r"""
69
+ .. deprecated:: v1.6
70
+ This callback hook was deprecated in v1.6 and will be removed in v1.8. Use ``setup()`` instead.
71
+
72
+ Called before the accelerator is set up.
73
+ """
74
+
75
+ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
76
+ """Called when fit, validate, test, predict, or tune begins."""
77
+
78
+ def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
79
+ """Called when fit, validate, test, predict, or tune ends."""
80
+
81
+ def on_init_start(self, trainer: "pl.Trainer") -> None:
82
+ r"""
83
+ .. deprecated:: v1.6
84
+ This callback hook was deprecated in v1.6 and will be removed in v1.8.
85
+
86
+ Called when the trainer initialization begins, model has not yet been set.
87
+ """
88
+
89
+ def on_init_end(self, trainer: "pl.Trainer") -> None:
90
+ r"""
91
+ .. deprecated:: v1.6
92
+ This callback hook was deprecated in v1.6 and will be removed in v1.8.
93
+
94
+ Called when the trainer initialization ends, model has not yet been set.
95
+ """
96
+
97
+ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
98
+ """Called when fit begins."""
99
+
100
+ def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
101
+ """Called when fit ends."""
102
+
103
+ def on_sanity_check_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
104
+ """Called when the validation sanity check starts."""
105
+
106
+ def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
107
+ """Called when the validation sanity check ends."""
108
+
109
+ def on_train_batch_start(
110
+ self,
111
+ trainer: "pl.Trainer",
112
+ pl_module: "pl.LightningModule",
113
+ batch: Any,
114
+ batch_idx: int,
115
+ unused: int = 0,
116
+ ) -> None:
117
+ """Called when the train batch begins."""
118
+
119
+ def on_train_batch_end(
120
+ self,
121
+ trainer: "pl.Trainer",
122
+ pl_module: "pl.LightningModule",
123
+ outputs: STEP_OUTPUT,
124
+ batch: Any,
125
+ batch_idx: int,
126
+ unused: int = 0,
127
+ ) -> None:
128
+ """Called when the train batch ends."""
129
+
130
+ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
131
+ """Called when the train epoch begins."""
132
+
133
+ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
134
+ """Called when the train epoch ends.
135
+
136
+ To access all batch outputs at the end of the epoch, either:
137
+
138
+ 1. Implement `training_epoch_end` in the `LightningModule` and access outputs via the module OR
139
+ 2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.
140
+ """
141
+
142
+ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
143
+ """Called when the val epoch begins."""
144
+
145
+ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
146
+ """Called when the val epoch ends."""
147
+
148
+ def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
149
+ """Called when the test epoch begins."""
150
+
151
+ def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
152
+ """Called when the test epoch ends."""
153
+
154
+ def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
155
+ """Called when the predict epoch begins."""
156
+
157
+ def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: List[Any]) -> None:
158
+ """Called when the predict epoch ends."""
159
+
160
+ def on_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
161
+ r"""
162
+ .. deprecated:: v1.6
163
+ This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
164
+ ``on_<train/validation/test>_epoch_start`` instead.
165
+
166
+ Called when either of train/val/test epoch begins.
167
+ """
168
+
169
+ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
170
+ r"""
171
+ .. deprecated:: v1.6
172
+ This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
173
+ ``on_<train/validation/test>_epoch_end`` instead.
174
+
175
+ Called when either of train/val/test epoch ends.
176
+ """
177
+
178
+ def on_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
179
+ r"""
180
+ .. deprecated:: v1.6
181
+ This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
182
+ ``on_train_batch_start`` instead.
183
+
184
+ Called when the training batch begins.
185
+ """
186
+
187
+ def on_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
188
+ r"""
189
+ .. deprecated:: v1.6
190
+ This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
191
+ ``on_train_batch_end`` instead.
192
+
193
+ Called when the training batch ends.
194
+ """
195
+
196
+ def on_validation_batch_start(
197
+ self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
198
+ ) -> None:
199
+ """Called when the validation batch begins."""
200
+
201
+ def on_validation_batch_end(
202
+ self,
203
+ trainer: "pl.Trainer",
204
+ pl_module: "pl.LightningModule",
205
+ outputs: Optional[STEP_OUTPUT],
206
+ batch: Any,
207
+ batch_idx: int,
208
+ dataloader_idx: int,
209
+ ) -> None:
210
+ """Called when the validation batch ends."""
211
+
212
+ def on_test_batch_start(
213
+ self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
214
+ ) -> None:
215
+ """Called when the test batch begins."""
216
+
217
+ def on_test_batch_end(
218
+ self,
219
+ trainer: "pl.Trainer",
220
+ pl_module: "pl.LightningModule",
221
+ outputs: Optional[STEP_OUTPUT],
222
+ batch: Any,
223
+ batch_idx: int,
224
+ dataloader_idx: int,
225
+ ) -> None:
226
+ """Called when the test batch ends."""
227
+
228
+ def on_predict_batch_start(
229
+ self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
230
+ ) -> None:
231
+ """Called when the predict batch begins."""
232
+
233
+ def on_predict_batch_end(
234
+ self,
235
+ trainer: "pl.Trainer",
236
+ pl_module: "pl.LightningModule",
237
+ outputs: Any,
238
+ batch: Any,
239
+ batch_idx: int,
240
+ dataloader_idx: int,
241
+ ) -> None:
242
+ """Called when the predict batch ends."""
243
+
244
+ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
245
+ """Called when the train begins."""
246
+
247
+ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
248
+ """Called when the train ends."""
249
+
250
+ def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
251
+ r"""
252
+ .. deprecated:: v1.6
253
+
254
+ This callback hook was deprecated in v1.6 and will be removed in v1.8. Use ``on_fit_start`` instead.
255
+
256
+ Called when the pretrain routine begins.
257
+ """
258
+
259
+ def on_pretrain_routine_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
260
+ r"""
261
+ .. deprecated:: v1.6
262
+
263
+ This callback hook was deprecated in v1.6 and will be removed in v1.8. Use ``on_fit_start`` instead.
264
+
265
+ Called when the pretrain routine ends.
266
+ """
267
+
268
+ def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
269
+ """Called when the validation loop begins."""
270
+
271
+ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
272
+ """Called when the validation loop ends."""
273
+
274
+ def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
275
+ """Called when the test begins."""
276
+
277
+ def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
278
+ """Called when the test ends."""
279
+
280
+ def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
281
+ """Called when the predict begins."""
282
+
283
+ def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
284
+ """Called when predict ends."""
285
+
286
+ def on_keyboard_interrupt(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
287
+ r"""
288
+ .. deprecated:: v1.5
289
+ This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7.
290
+
291
+ Called when any trainer execution is interrupted by KeyboardInterrupt.
292
+ """
293
+
294
+ def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
295
+ """Called when any trainer execution is interrupted by an exception."""
296
+
297
+ def state_dict(self) -> Dict[str, Any]:
298
+ """Called when saving a checkpoint, implement to generate callback's ``state_dict``.
299
+
300
+ Returns:
301
+ A dictionary containing callback state.
302
+ """
303
+ return {}
304
+
305
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
306
+ """Called when loading a checkpoint, implement to reload callback state given callback's ``state_dict``.
307
+
308
+ Args:
309
+ state_dict: the callback state returned by ``state_dict``.
310
+ """
311
+ pass
312
+
313
+ def on_save_checkpoint(
314
+ self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
315
+ ) -> Optional[dict]:
316
+ r"""
317
+ Called when saving a checkpoint to give you a chance to store anything else you might want to save.
318
+
319
+ Args:
320
+ trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
321
+ pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance.
322
+ checkpoint: the checkpoint dictionary that will be saved.
323
+
324
+ Returns:
325
+ None or the callback state. Support for returning callback state will be removed in v1.8.
326
+
327
+ .. deprecated:: v1.6
328
+ Returning a value from this method was deprecated in v1.6 and will be removed in v1.8.
329
+ Implement ``Callback.state_dict`` instead to return state.
330
+ In v1.8 ``Callback.on_save_checkpoint`` can only return None.
331
+ """
332
+
333
+ def on_load_checkpoint(
334
+ self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
335
+ ) -> None:
336
+ r"""
337
+ Called when loading a model checkpoint, use to reload state.
338
+
339
+ Args:
340
+ trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
341
+ pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance.
342
+ callback_state: the callback state returned by ``on_save_checkpoint``.
343
+
344
+ Note:
345
+ The ``on_load_checkpoint`` won't be called with an undefined state.
346
+ If your ``on_load_checkpoint`` hook behavior doesn't rely on a state,
347
+ you will still need to override ``on_save_checkpoint`` to return a ``dummy state``.
348
+
349
+ .. deprecated:: v1.6
350
+ This callback hook will change its signature and behavior in v1.8.
351
+ If you wish to load the state of the callback, use ``Callback.load_state_dict`` instead.
352
+ In v1.8 ``Callback.on_load_checkpoint(checkpoint)`` will receive the entire loaded
353
+ checkpoint dictionary instead of only the callback state from the checkpoint.
354
+ """
355
+
356
+ def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: torch.Tensor) -> None:
357
+ """Called before ``loss.backward()``."""
358
+
359
+ def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
360
+ """Called after ``loss.backward()`` and before optimizers are stepped."""
361
+
362
+ def on_before_optimizer_step(
363
+ self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer, opt_idx: int
364
+ ) -> None:
365
+ """Called before ``optimizer.step()``."""
366
+
367
+ def on_before_zero_grad(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer) -> None:
368
+ """Called before ``optimizer.zero_grad()``."""
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/device_stats_monitor.py ADDED
@@ -0,0 +1,104 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Device Stats Monitor
16
+ ====================
17
+
18
+ Monitors and logs device stats during training.
19
+
20
+ """
21
+ from typing import Any, Dict, Optional
22
+
23
+ import pytorch_lightning as pl
24
+ from pytorch_lightning.callbacks.base import Callback
25
+ from pytorch_lightning.utilities.exceptions import MisconfigurationException
26
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
27
+ from pytorch_lightning.utilities.warnings import rank_zero_deprecation
28
+
29
+
30
+ class DeviceStatsMonitor(Callback):
31
+ r"""
32
+ Automatically monitors and logs device stats during the training stage. ``DeviceStatsMonitor``
33
+ is a special callback, as it requires a ``logger`` to be passed as an argument to the ``Trainer``.
34
+
35
+ Raises:
36
+ MisconfigurationException:
37
+ If ``Trainer`` has no logger.
38
+
39
+ Example:
40
+ >>> from pytorch_lightning import Trainer
41
+ >>> from pytorch_lightning.callbacks import DeviceStatsMonitor
42
+ >>> device_stats = DeviceStatsMonitor() # doctest: +SKIP
43
+ >>> trainer = Trainer(callbacks=[device_stats]) # doctest: +SKIP
44
+ """
45
+
46
+ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
47
+ if not trainer.loggers:
48
+ raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.")
49
+
50
+ def on_train_batch_start(
51
+ self,
52
+ trainer: "pl.Trainer",
53
+ pl_module: "pl.LightningModule",
54
+ batch: Any,
55
+ batch_idx: int,
56
+ unused: int = 0,
57
+ ) -> None:
58
+ if not trainer.loggers:
59
+ raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")
60
+
61
+ if not trainer._logger_connector.should_update_logs:
62
+ return
63
+
64
+ device = trainer.strategy.root_device
65
+ device_stats = trainer.accelerator.get_device_stats(device)
66
+ for logger in trainer.loggers:
67
+ separator = logger.group_separator
68
+ prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
69
+ logger.log_metrics(prefixed_device_stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped)
70
+
71
+ def on_train_batch_end(
72
+ self,
73
+ trainer: "pl.Trainer",
74
+ pl_module: "pl.LightningModule",
75
+ outputs: STEP_OUTPUT,
76
+ batch: Any,
77
+ batch_idx: int,
78
+ unused: int = 0,
79
+ ) -> None:
80
+ if not trainer.loggers:
81
+ raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")
82
+
83
+ if not trainer._logger_connector.should_update_logs:
84
+ return
85
+
86
+ device = trainer.strategy.root_device
87
+ device_stats = trainer.accelerator.get_device_stats(device)
88
+ for logger in trainer.loggers:
89
+ separator = logger.group_separator
90
+ prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
91
+ logger.log_metrics(prefixed_device_stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped)
92
+
93
+
94
+ def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
95
+ return {prefix + separator + k: v for k, v in metrics_dict.items()}
96
+
97
+
98
+ def prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str) -> Dict[str, float]:
99
+ rank_zero_deprecation(
100
+ "`pytorch_lightning.callbacks.device_stats_monitor.prefix_metrics`"
101
+ " is deprecated in v1.6 and will be removed in v1.8."
102
+ )
103
+ sep = ""
104
+ return _prefix_metric_keys(metrics_dict, prefix, sep)
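
For completeness, a small sketch of how the DeviceStatsMonitor above is wired into a run with a logger, since the callback raises if the Trainer has none (the logger choice and save path are placeholders):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import DeviceStatsMonitor
    from pytorch_lightning.loggers import CSVLogger

    # DeviceStatsMonitor requires a logger; CSVLogger is just one option.
    trainer = Trainer(
        logger=CSVLogger(save_dir="logs/"),  # placeholder save directory
        callbacks=[DeviceStatsMonitor()],
    )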
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/early_stopping.py ADDED
@@ -0,0 +1,261 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ r"""
15
+ Early Stopping
16
+ ^^^^^^^^^^^^^^
17
+
18
+ Monitor a metric and stop training when it stops improving.
19
+
20
+ """
21
+ import logging
22
+ from typing import Any, Callable, Dict, Optional, Tuple
23
+
24
+ import numpy as np
25
+ import torch
26
+
27
+ import pytorch_lightning as pl
28
+ from pytorch_lightning.callbacks.base import Callback
29
+ from pytorch_lightning.utilities.exceptions import MisconfigurationException
30
+ from pytorch_lightning.utilities.rank_zero import rank_zero_warn
31
+
32
+ log = logging.getLogger(__name__)
33
+
34
+
35
+ class EarlyStopping(Callback):
36
+ r"""
37
+ Monitor a metric and stop training when it stops improving.
38
+
39
+ Args:
40
+ monitor: quantity to be monitored.
41
+ min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute
42
+ change of less than or equal to `min_delta`, will count as no improvement.
43
+ patience: number of checks with no improvement
44
+ after which training will be stopped. Under the default configuration, one check happens after
45
+ every training epoch. However, the frequency of validation can be modified by setting various parameters on
46
+ the ``Trainer``, for example ``check_val_every_n_epoch`` and ``val_check_interval``.
47
+
48
+ .. note::
49
+
50
+ It must be noted that the patience parameter counts the number of validation checks with
51
+ no improvement, and not the number of training epochs. Therefore, with parameters
52
+ ``check_val_every_n_epoch=10`` and ``patience=3``, the trainer will perform at least 40 training
53
+ epochs before being stopped.
54
+
55
+ verbose: verbosity mode.
56
+ mode: one of ``'min'``, ``'max'``. In ``'min'`` mode, training will stop when the quantity
57
+ monitored has stopped decreasing and in ``'max'`` mode it will stop when the quantity
58
+ monitored has stopped increasing.
59
+ strict: whether to crash the training if `monitor` is not found in the validation metrics.
60
+ check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite.
61
+ stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold.
62
+ divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold.
63
+ check_on_train_epoch_end: whether to run early stopping at the end of the training epoch.
64
+ If this is ``False``, then the check runs at the end of the validation.
65
+
66
+ Raises:
67
+ MisconfigurationException:
68
+ If ``mode`` is none of ``"min"`` or ``"max"``.
69
+ RuntimeError:
70
+ If the metric ``monitor`` is not available.
71
+
72
+ Example::
73
+
74
+ >>> from pytorch_lightning import Trainer
75
+ >>> from pytorch_lightning.callbacks import EarlyStopping
76
+ >>> early_stopping = EarlyStopping('val_loss')
77
+ >>> trainer = Trainer(callbacks=[early_stopping])
78
+
79
+ .. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported under variation in the
80
+ following arguments:
81
+
82
+ *monitor, mode*
83
+
84
+ Read more: :ref:`Persisting Callback State`
85
+ """
86
+ mode_dict = {"min": torch.lt, "max": torch.gt}
87
+
88
+ order_dict = {"min": "<", "max": ">"}
89
+
90
+ def __init__(
91
+ self,
92
+ monitor: str,
93
+ min_delta: float = 0.0,
94
+ patience: int = 3,
95
+ verbose: bool = False,
96
+ mode: str = "min",
97
+ strict: bool = True,
98
+ check_finite: bool = True,
99
+ stopping_threshold: Optional[float] = None,
100
+ divergence_threshold: Optional[float] = None,
101
+ check_on_train_epoch_end: Optional[bool] = None,
102
+ ):
103
+ super().__init__()
104
+ self.monitor = monitor
105
+ self.min_delta = min_delta
106
+ self.patience = patience
107
+ self.verbose = verbose
108
+ self.mode = mode
109
+ self.strict = strict
110
+ self.check_finite = check_finite
111
+ self.stopping_threshold = stopping_threshold
112
+ self.divergence_threshold = divergence_threshold
113
+ self.wait_count = 0
114
+ self.stopped_epoch = 0
115
+ self._check_on_train_epoch_end = check_on_train_epoch_end
116
+
117
+ if self.mode not in self.mode_dict:
118
+ raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
119
+
120
+ self.min_delta *= 1 if self.monitor_op == torch.gt else -1
121
+ torch_inf = torch.tensor(np.Inf)
122
+ self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
123
+
124
+ @property
125
+ def state_key(self) -> str:
126
+ return self._generate_state_key(monitor=self.monitor, mode=self.mode)
127
+
128
+ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
129
+ if self._check_on_train_epoch_end is None:
130
+ # if the user runs validation multiple times per training epoch or multiple training epochs without
131
+ # validation, then we run after validation instead of on train epoch end
132
+ self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
133
+
134
+ def _validate_condition_metric(self, logs: Dict[str, float]) -> bool:
135
+ monitor_val = logs.get(self.monitor)
136
+
137
+ error_msg = (
138
+ f"Early stopping conditioned on metric `{self.monitor}` which is not available."
139
+ " Pass in or modify your `EarlyStopping` callback to use any of the following:"
140
+ f' `{"`, `".join(list(logs.keys()))}`'
141
+ )
142
+
143
+ if monitor_val is None:
144
+ if self.strict:
145
+ raise RuntimeError(error_msg)
146
+ if self.verbose > 0:
147
+ rank_zero_warn(error_msg, category=RuntimeWarning)
148
+
149
+ return False
150
+
151
+ return True
152
+
153
+ @property
154
+ def monitor_op(self) -> Callable:
155
+ return self.mode_dict[self.mode]
156
+
157
+ def state_dict(self) -> Dict[str, Any]:
158
+ return {
159
+ "wait_count": self.wait_count,
160
+ "stopped_epoch": self.stopped_epoch,
161
+ "best_score": self.best_score,
162
+ "patience": self.patience,
163
+ }
164
+
165
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
166
+ self.wait_count = state_dict["wait_count"]
167
+ self.stopped_epoch = state_dict["stopped_epoch"]
168
+ self.best_score = state_dict["best_score"]
169
+ self.patience = state_dict["patience"]
170
+
171
+ def _should_skip_check(self, trainer: "pl.Trainer") -> bool:
172
+ from pytorch_lightning.trainer.states import TrainerFn
173
+
174
+ return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking
175
+
176
+ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
177
+ if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
178
+ return
179
+ self._run_early_stopping_check(trainer)
180
+
181
+ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
182
+ if self._check_on_train_epoch_end or self._should_skip_check(trainer):
183
+ return
184
+ self._run_early_stopping_check(trainer)
185
+
186
+ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
187
+ """Checks whether the early stopping condition is met and if so tells the trainer to stop the training."""
188
+ logs = trainer.callback_metrics
189
+
190
+ if trainer.fast_dev_run or not self._validate_condition_metric( # disable early_stopping with fast_dev_run
191
+ logs
192
+ ): # short circuit if metric not present
193
+ return
194
+
195
+ current = logs[self.monitor].squeeze()
196
+ should_stop, reason = self._evaluate_stopping_criteria(current)
197
+
198
+ # stop every ddp process if any world process decides to stop
199
+ should_stop = trainer.strategy.reduce_boolean_decision(should_stop)
200
+ trainer.should_stop = trainer.should_stop or should_stop
201
+ if should_stop:
202
+ self.stopped_epoch = trainer.current_epoch
203
+ if reason and self.verbose:
204
+ self._log_info(trainer, reason)
205
+
206
+ def _evaluate_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, Optional[str]]:
207
+ should_stop = False
208
+ reason = None
209
+ if self.check_finite and not torch.isfinite(current):
210
+ should_stop = True
211
+ reason = (
212
+ f"Monitored metric {self.monitor} = {current} is not finite."
213
+ f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
214
+ )
215
+ elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
216
+ should_stop = True
217
+ reason = (
218
+ "Stopping threshold reached:"
219
+ f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}."
220
+ " Signaling Trainer to stop."
221
+ )
222
+ elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
223
+ should_stop = True
224
+ reason = (
225
+ "Divergence threshold reached:"
226
+ f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
227
+ " Signaling Trainer to stop."
228
+ )
229
+ elif self.monitor_op(current - self.min_delta, self.best_score.to(current.device)):
230
+ should_stop = False
231
+ reason = self._improvement_message(current)
232
+ self.best_score = current
233
+ self.wait_count = 0
234
+ else:
235
+ self.wait_count += 1
236
+ if self.wait_count >= self.patience:
237
+ should_stop = True
238
+ reason = (
239
+ f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
240
+ f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
241
+ )
242
+
243
+ return should_stop, reason
244
+
245
+ def _improvement_message(self, current: torch.Tensor) -> str:
246
+ """Formats a log message that informs the user about an improvement in the monitored score."""
247
+ if torch.isfinite(self.best_score):
248
+ msg = (
249
+ f"Metric {self.monitor} improved by {abs(self.best_score - current):.3f} >="
250
+ f" min_delta = {abs(self.min_delta)}. New best score: {current:.3f}"
251
+ )
252
+ else:
253
+ msg = f"Metric {self.monitor} improved. New best score: {current:.3f}"
254
+ return msg
255
+
256
+ @staticmethod
257
+ def _log_info(trainer: Optional["pl.Trainer"], message: str) -> None:
258
+ if trainer is not None and trainer.world_size > 1:
259
+ log.info(f"[rank: {trainer.global_rank}] {message}")
260
+ else:
261
+ log.info(message)
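
Putting the EarlyStopping arguments documented above together, a slightly fuller configuration than the docstring example might look like this sketch (the metric name and numeric values are placeholders; the monitored key must be logged by the LightningModule, e.g. via self.log("val_loss", ...)):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    early_stopping = EarlyStopping(
        monitor="val_loss",  # placeholder metric name; must be logged during validation
        mode="min",          # stop once the monitored value stops decreasing
        patience=3,          # counted in validation checks, not training epochs (see the note above)
        min_delta=1e-3,      # improvements smaller than this count as no improvement
    )
    trainer = Trainer(callbacks=[early_stopping])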
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/finetuning.py ADDED
@@ -0,0 +1,417 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ r"""
15
+ Finetuning Callback
16
+ ^^^^^^^^^^^^^^^^^^^^
17
+ Freeze and unfreeze models for finetuning purposes
18
+ """
19
+ import logging
20
+ from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
21
+
22
+ import torch
23
+ from torch.nn import Module, ModuleDict
24
+ from torch.nn.modules.batchnorm import _BatchNorm
25
+ from torch.optim.optimizer import Optimizer
26
+
27
+ import pytorch_lightning as pl
28
+ from pytorch_lightning.callbacks.base import Callback
29
+ from pytorch_lightning.utilities.exceptions import MisconfigurationException
30
+ from pytorch_lightning.utilities.rank_zero import rank_zero_warn
31
+
32
+ log = logging.getLogger(__name__)
33
+
34
+
35
+ def multiplicative(epoch):
36
+ return 2
37
+
38
+
39
+ class BaseFinetuning(Callback):
40
+ r"""
41
+ This class implements the base logic for writing your own Finetuning Callback.
42
+
43
+ Override ``freeze_before_training`` and ``finetune_function`` methods with your own logic.
44
+
45
+ ``freeze_before_training``: This method is called before ``configure_optimizers``
46
+ and should be used to freeze any modules parameters.
47
+
48
+ ``finetune_function``: This method is called on every train epoch start and should be used to
49
+ ``unfreeze`` any parameters. Those parameters need to be added to a new ``param_group``
50
+ within the optimizer.
51
+
52
+ .. note:: Make sure to filter the parameters based on ``requires_grad``.
53
+
54
+ Example::
55
+
56
+ >>> from torch.optim import Adam
57
+ >>> class MyModel(pl.LightningModule):
58
+ ... def configure_optimizer(self):
59
+ ... # Make sure to filter the parameters based on `requires_grad`
60
+ ... return Adam(filter(lambda p: p.requires_grad, self.parameters()))
61
+ ...
62
+ >>> class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
63
+ ... def __init__(self, unfreeze_at_epoch=10):
64
+ ... super().__init__()
65
+ ... self._unfreeze_at_epoch = unfreeze_at_epoch
66
+ ...
67
+ ... def freeze_before_training(self, pl_module):
68
+ ... # freeze any module you want
69
+ ... # Here, we are freezing `feature_extractor`
70
+ ... self.freeze(pl_module.feature_extractor)
71
+ ...
72
+ ... def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
73
+ ... # When `current_epoch` is 10, feature_extractor will start training.
74
+ ... if current_epoch == self._unfreeze_at_epoch:
75
+ ... self.unfreeze_and_add_param_group(
76
+ ... modules=pl_module.feature_extractor,
77
+ ... optimizer=optimizer,
78
+ ... train_bn=True,
79
+ ... )
80
+ """
81
+
82
+ def __init__(self):
83
+ self._internal_optimizer_metadata: Dict[int, List[Dict[str, Any]]] = {}
84
+ self._restarting = False
85
+
86
+ def state_dict(self) -> Dict[str, Any]:
87
+ return {
88
+ "internal_optimizer_metadata": self._internal_optimizer_metadata,
89
+ }
90
+
91
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
92
+ self._restarting = True
93
+ if "internal_optimizer_metadata" in state_dict:
94
+ self._internal_optimizer_metadata = state_dict["internal_optimizer_metadata"]
95
+ else:
96
+ # compatibility to load from old checkpoints before PR #11887
97
+ self._internal_optimizer_metadata = state_dict
98
+
99
+ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
100
+ # restore the param_groups created during the previous training.
101
+ if self._restarting:
102
+ named_parameters = dict(pl_module.named_parameters())
103
+ for opt_idx, optimizer in enumerate(trainer.optimizers):
104
+ param_groups = self._apply_mapping_to_param_groups(
105
+ self._internal_optimizer_metadata[opt_idx], named_parameters
106
+ )
107
+ optimizer.param_groups = param_groups
108
+ self._restarting = False
109
+
110
+ @staticmethod
111
+ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
112
+ """This function is used to flatten a module or an iterable of modules into a list of its leaf modules
113
+ (modules with no children) and parent modules that have parameters directly themselves.
114
+
115
+ Args:
116
+ modules: A given module or an iterable of modules
117
+
118
+ Returns:
119
+ List of modules
120
+ """
121
+ if isinstance(modules, ModuleDict):
122
+ modules = modules.values()
123
+
124
+ if isinstance(modules, Iterable):
125
+ _modules = []
126
+ for m in modules:
127
+ _modules.extend(BaseFinetuning.flatten_modules(m))
128
+
129
+ else:
130
+ _modules = modules.modules()
131
+
132
+ # Capture all leaf modules as well as parent modules that have parameters directly themselves
133
+ return [m for m in _modules if not list(m.children()) or m._parameters]
134
+
135
+ @staticmethod
136
+ def filter_params(
137
+ modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True, requires_grad: bool = True
138
+ ) -> Generator:
139
+ """Yields the `requires_grad` parameters of a given module or list of modules.
140
+
141
+ Args:
142
+ modules: A given module or an iterable of modules
143
+ train_bn: Whether to train BatchNorm module
144
+ requires_grad: Whether to create a generator for trainable or non-trainable parameters.
145
+ Returns:
146
+ Generator
147
+ """
148
+ modules = BaseFinetuning.flatten_modules(modules)
149
+ for mod in modules:
150
+ if isinstance(mod, _BatchNorm) and not train_bn:
151
+ continue
152
+ # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
153
+ for param in mod.parameters(recurse=False):
154
+ if param.requires_grad == requires_grad:
155
+ yield param
156
+
157
+ @staticmethod
158
+ def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> None:
159
+ """Unfreezes the parameters of the provided modules.
160
+
161
+ Args:
162
+ modules: A given module or an iterable of modules
163
+ """
164
+ modules = BaseFinetuning.flatten_modules(modules)
165
+ for module in modules:
166
+ # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
167
+ for param in module.parameters(recurse=False):
168
+ param.requires_grad = True
169
+
170
+ @staticmethod
171
+ def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True) -> None:
172
+ """Freezes the parameters of the provided modules.
173
+
174
+ Args:
175
+ modules: A given module or an iterable of modules
176
+ train_bn: If True, leave the BatchNorm layers in training mode
177
+
178
+ Returns:
179
+ None
180
+ """
181
+ modules = BaseFinetuning.flatten_modules(modules)
182
+ for mod in modules:
183
+ if isinstance(mod, _BatchNorm) and train_bn:
184
+ BaseFinetuning.make_trainable(mod)
185
+ else:
186
+ # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
187
+ for param in mod.parameters(recurse=False):
188
+ param.requires_grad = False
189
+
190
+ @staticmethod
191
+ def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List:
192
+ """This function is used to exclude any parameter which already exists in this optimizer.
193
+
194
+ Args:
195
+ optimizer: Optimizer used for parameter exclusion
196
+ params: Iterable of parameters used to check against the provided optimizer
197
+
198
+ Returns:
199
+ List of parameters not contained in this optimizer param groups
200
+ """
201
+ out_params = []
202
+ removed_params = []
203
+ for param in params:
204
+ if not any(torch.equal(p, param) for group in optimizer.param_groups for p in group["params"]):
205
+ out_params.append(param)
206
+ else:
207
+ removed_params.append(param)
208
+
209
+ if removed_params:
210
+ rank_zero_warn(
211
+ "The provided params to be frozen already exist within another group of this optimizer."
212
+ " Those parameters will be skipped.\n"
213
+ "HINT: Did you init your optimizer in `configure_optimizer` as such:\n"
214
+ f" {type(optimizer)}(filter(lambda p: p.requires_grad, self.parameters()), ...) ",
215
+ )
216
+ return out_params
217
+
218
+ @staticmethod
219
+ def unfreeze_and_add_param_group(
220
+ modules: Union[Module, Iterable[Union[Module, Iterable]]],
221
+ optimizer: Optimizer,
222
+ lr: Optional[float] = None,
223
+ initial_denom_lr: float = 10.0,
224
+ train_bn: bool = True,
225
+ ) -> None:
226
+ """Unfreezes a module and adds its parameters to an optimizer.
227
+
228
+ Args:
229
+ modules: A module or iterable of modules to unfreeze.
230
+ Their parameters will be added to an optimizer as a new param group.
231
+ optimizer: The provided optimizer will receive new parameters and will add them to
232
+ `add_param_group`
233
+ lr: Learning rate for the new param group.
234
+ initial_denom_lr: If no lr is provided, the learning from the first param group will be used
235
+ and divided by `initial_denom_lr`.
236
+ train_bn: Whether to train the BatchNormalization layers.
237
+ """
238
+ BaseFinetuning.make_trainable(modules)
239
+ params_lr = optimizer.param_groups[0]["lr"] if lr is None else float(lr)
240
+ denom_lr = initial_denom_lr if lr is None else 1.0
241
+ params = BaseFinetuning.filter_params(modules, train_bn=train_bn, requires_grad=True)
242
+ params = BaseFinetuning.filter_on_optimizer(optimizer, params)
243
+ if params:
244
+ optimizer.add_param_group({"params": params, "lr": params_lr / denom_lr})
245
+
246
+ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
247
+ self.freeze_before_training(pl_module)
248
+
249
+ @staticmethod
250
+ def _apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]:
251
+ output = []
252
+ for g in param_groups:
253
+ # skip params to save memory
254
+ group_state = {k: v for k, v in g.items() if k != "params"}
255
+ group_state["params"] = [mapping[p] for p in g["params"]]
256
+ output.append(group_state)
257
+ return output
258
+
259
+ def _store(
260
+ self,
261
+ pl_module: "pl.LightningModule",
262
+ opt_idx: int,
263
+ num_param_groups: int,
264
+ current_param_groups: List[Dict[str, Any]],
265
+ ) -> None:
266
+ mapping = {p: n for n, p in pl_module.named_parameters()}
267
+ if opt_idx not in self._internal_optimizer_metadata:
268
+ self._internal_optimizer_metadata[opt_idx] = self._apply_mapping_to_param_groups(
269
+ current_param_groups, mapping
270
+ )
271
+ elif num_param_groups != len(current_param_groups):
272
+ # save new param_groups possibly created by the users.
273
+ self._internal_optimizer_metadata[opt_idx].extend(
274
+ self._apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping)
275
+ )
276
+
277
+ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
278
+ """Called when the epoch begins."""
279
+ # import is here to avoid circular imports
280
+ from pytorch_lightning.loops.utilities import _get_active_optimizers
281
+
282
+ for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies):
283
+ num_param_groups = len(optimizer.param_groups)
284
+ self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
285
+ current_param_groups = optimizer.param_groups
286
+ self._store(pl_module, opt_idx, num_param_groups, current_param_groups)
287
+
288
+ def finetune_function(
289
+ self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int
290
+ ) -> None:
291
+ """Override to add your unfreeze logic."""
292
+ raise NotImplementedError
293
+
294
+ def freeze_before_training(self, pl_module: "pl.LightningModule") -> None:
295
+ """Override to add your freeze logic."""
296
+ raise NotImplementedError
297
+
298
+
299
+ class BackboneFinetuning(BaseFinetuning):
300
+ r"""Finetune a backbone model based on a learning rate user-defined scheduling.
301
+
302
+ When the backbone learning rate reaches the current model learning rate
303
+ and ``should_align`` is set to True, it will align with it for the rest of the training.
304
+
305
+ Args:
306
+ unfreeze_backbone_at_epoch: Epoch at which the backbone will be unfrozen.
307
+ lambda_func: Scheduling function for increasing backbone learning rate.
308
+ backbone_initial_ratio_lr:
309
+ Used to scale down the backbone learning rate compared to rest of model
310
+ backbone_initial_lr: Optional, Initial learning rate for the backbone.
311
+ By default, we will use ``current_learning / backbone_initial_ratio_lr``
312
+ should_align: Whether to align with current learning rate when backbone learning
313
+ reaches it.
314
+ initial_denom_lr: When unfreezing the backbone, the initial learning rate will
315
+ ``current_learning_rate / initial_denom_lr``.
316
+ train_bn: Whether to make Batch Normalization trainable.
317
+ verbose: Display current learning rate for model and backbone
318
+ rounding: Precision for displaying learning rate
319
+
320
+ Example::
321
+
322
+ >>> from pytorch_lightning import Trainer
323
+ >>> from pytorch_lightning.callbacks import BackboneFinetuning
324
+ >>> multiplicative = lambda epoch: 1.5
325
+ >>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
326
+ >>> trainer = Trainer(callbacks=[backbone_finetuning])
327
+
328
+ """
329
+
330
+ def __init__(
331
+ self,
332
+ unfreeze_backbone_at_epoch: int = 10,
333
+ lambda_func: Callable = multiplicative,
334
+ backbone_initial_ratio_lr: float = 10e-2,
335
+ backbone_initial_lr: Optional[float] = None,
336
+ should_align: bool = True,
337
+ initial_denom_lr: float = 10.0,
338
+ train_bn: bool = True,
339
+ verbose: bool = False,
340
+ rounding: int = 12,
341
+ ) -> None:
342
+ super().__init__()
343
+
344
+ self.unfreeze_backbone_at_epoch: int = unfreeze_backbone_at_epoch
345
+ self.lambda_func: Callable = lambda_func
346
+ self.backbone_initial_ratio_lr: float = backbone_initial_ratio_lr
347
+ self.backbone_initial_lr: Optional[float] = backbone_initial_lr
348
+ self.should_align: bool = should_align
349
+ self.initial_denom_lr: float = initial_denom_lr
350
+ self.train_bn: bool = train_bn
351
+ self.verbose: bool = verbose
352
+ self.rounding: int = rounding
353
+ self.previous_backbone_lr: Optional[float] = None
354
+
355
+ def state_dict(self) -> Dict[str, Any]:
356
+ return {
357
+ "internal_optimizer_metadata": self._internal_optimizer_metadata,
358
+ "previous_backbone_lr": self.previous_backbone_lr,
359
+ }
360
+
361
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
362
+ self.previous_backbone_lr = state_dict["previous_backbone_lr"]
363
+ super().load_state_dict(state_dict)
364
+
365
+ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
366
+ """
367
+ Raises:
368
+ MisconfigurationException:
369
+ If LightningModule has no nn.Module `backbone` attribute.
370
+ """
371
+ if hasattr(pl_module, "backbone") and isinstance(pl_module.backbone, Module):
372
+ return super().on_fit_start(trainer, pl_module)
373
+ raise MisconfigurationException("The LightningModule should have a nn.Module `backbone` attribute")
374
+
375
+ def freeze_before_training(self, pl_module: "pl.LightningModule") -> None:
376
+ self.freeze(pl_module.backbone)
377
+
378
+ def finetune_function(
379
+ self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int
380
+ ) -> None:
381
+ """Called when the epoch begins."""
382
+ if epoch == self.unfreeze_backbone_at_epoch:
383
+ current_lr = optimizer.param_groups[0]["lr"]
384
+ initial_backbone_lr = (
385
+ self.backbone_initial_lr
386
+ if self.backbone_initial_lr is not None
387
+ else current_lr * self.backbone_initial_ratio_lr
388
+ )
389
+ self.previous_backbone_lr = initial_backbone_lr
390
+ self.unfreeze_and_add_param_group(
391
+ pl_module.backbone,
392
+ optimizer,
393
+ initial_backbone_lr,
394
+ train_bn=self.train_bn,
395
+ initial_denom_lr=self.initial_denom_lr,
396
+ )
397
+ if self.verbose:
398
+ log.info(
399
+ f"Current lr: {round(current_lr, self.rounding)}, "
400
+ f"Backbone lr: {round(initial_backbone_lr, self.rounding)}"
401
+ )
402
+
403
+ elif epoch > self.unfreeze_backbone_at_epoch:
404
+ current_lr = optimizer.param_groups[0]["lr"]
405
+ next_current_backbone_lr = self.lambda_func(epoch + 1) * self.previous_backbone_lr
406
+ next_current_backbone_lr = (
407
+ current_lr
408
+ if (self.should_align and next_current_backbone_lr > current_lr)
409
+ else next_current_backbone_lr
410
+ )
411
+ optimizer.param_groups[-1]["lr"] = next_current_backbone_lr
412
+ self.previous_backbone_lr = next_current_backbone_lr
413
+ if self.verbose:
414
+ log.info(
415
+ f"Current lr: {round(current_lr, self.rounding)}, "
416
+ f"Backbone lr: {round(next_current_backbone_lr, self.rounding)}"
417
+ )
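The intended way to extend ``BaseFinetuning`` (of which ``BackboneFinetuning`` above is one example) is to override only the two hooks ``freeze_before_training`` and ``finetune_function``. A minimal sketch, assuming the LightningModule exposes a hypothetical ``feature_extractor`` submodule:

    from pytorch_lightning.callbacks import BaseFinetuning


    class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
        def __init__(self, unfreeze_at_epoch: int = 10):
            super().__init__()
            self._unfreeze_at_epoch = unfreeze_at_epoch

        def freeze_before_training(self, pl_module):
            # freeze the (hypothetical) feature extractor before training starts
            self.freeze(pl_module.feature_extractor)

        def finetune_function(self, pl_module, epoch, optimizer, opt_idx):
            # when the chosen epoch is reached, unfreeze the module and add its
            # parameters to the optimizer as a new param group with a scaled-down lr
            if epoch == self._unfreeze_at_epoch:
                self.unfreeze_and_add_param_group(
                    modules=pl_module.feature_extractor,
                    optimizer=optimizer,
                    train_bn=True,
                )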
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/gpu_stats_monitor.py ADDED
@@ -0,0 +1,262 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ GPU Stats Monitor
16
+ =================
17
+
18
+ Monitors and logs GPU stats during training.
19
+
20
+ """
21
+
22
+ import os
23
+ import shutil
24
+ import subprocess
25
+ import time
26
+ from typing import Any, Dict, List, Optional, Tuple
27
+
28
+ import torch
29
+
30
+ import pytorch_lightning as pl
31
+ from pytorch_lightning.callbacks.base import Callback
32
+ from pytorch_lightning.utilities.exceptions import MisconfigurationException
33
+ from pytorch_lightning.utilities.parsing import AttributeDict
34
+ from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only
35
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
36
+
37
+
38
+ class GPUStatsMonitor(Callback):
39
+ r"""
40
+ .. deprecated:: v1.5
41
+ The `GPUStatsMonitor` callback was deprecated in v1.5 and will be removed in v1.7.
42
+ Please use the `DeviceStatsMonitor` callback instead.
43
+
44
+ Automatically monitors and logs GPU stats during the training stage. ``GPUStatsMonitor``
45
+ is a callback and, in order to use it, you need to assign a logger in the ``Trainer``.
46
+
47
+ Args:
48
+ memory_utilization: Set to ``True`` to monitor used, free and percentage of memory
49
+ utilization at the start and end of each step. Default: ``True``.
50
+ gpu_utilization: Set to ``True`` to monitor percentage of GPU utilization
51
+ at the start and end of each step. Default: ``True``.
52
+ intra_step_time: Set to ``True`` to monitor the time of each step. Default: ``False``.
53
+ inter_step_time: Set to ``True`` to monitor the time between the end of one step
54
+ and the start of the next step. Default: ``False``.
55
+ fan_speed: Set to ``True`` to monitor percentage of fan speed. Default: ``False``.
56
+ temperature: Set to ``True`` to monitor the memory and gpu temperature in degree Celsius.
57
+ Default: ``False``.
58
+
59
+ Raises:
60
+ MisconfigurationException:
61
+ If NVIDIA driver is not installed, not running on GPUs, or ``Trainer`` has no logger.
62
+
63
+ Example::
64
+
65
+ >>> from pytorch_lightning import Trainer
66
+ >>> from pytorch_lightning.callbacks import GPUStatsMonitor
67
+ >>> gpu_stats = GPUStatsMonitor() # doctest: +SKIP
68
+ >>> trainer = Trainer(callbacks=[gpu_stats]) # doctest: +SKIP
69
+
70
+ GPU stats are mainly based on the `nvidia-smi --query-gpu` command. The description of the queries is as follows:
71
+
72
+ - **fan.speed** – The fan speed value is the percent of maximum speed that the device's fan is currently
73
+ intended to run at. It ranges from 0 to 100 %. Note: The reported speed is the intended fan speed.
74
+ If the fan is physically blocked and unable to spin, this output will not match the actual fan speed.
75
+ Many parts do not report fan speeds because they rely on cooling via fans in the surrounding enclosure.
76
+ - **memory.used** – Total memory allocated by active contexts.
77
+ - **memory.free** – Total free memory.
78
+ - **utilization.gpu** – Percent of time over the past sample period during which one or more kernels was
79
+ executing on the GPU. The sample period may be between 1 second and 1/6 second depending on the product.
80
+ - **utilization.memory** – Percent of time over the past sample period during which global (device) memory was
81
+ being read or written. The sample period may be between 1 second and 1/6 second depending on the product.
82
+ - **temperature.gpu** – Core GPU temperature, in degrees C.
83
+ - **temperature.memory** – HBM memory temperature, in degrees C.
84
+
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ memory_utilization: bool = True,
90
+ gpu_utilization: bool = True,
91
+ intra_step_time: bool = False,
92
+ inter_step_time: bool = False,
93
+ fan_speed: bool = False,
94
+ temperature: bool = False,
95
+ ):
96
+ super().__init__()
97
+
98
+ rank_zero_deprecation(
99
+ "The `GPUStatsMonitor` callback was deprecated in v1.5 and will be removed in v1.7."
100
+ " Please use the `DeviceStatsMonitor` callback instead."
101
+ )
102
+
103
+ if shutil.which("nvidia-smi") is None:
104
+ raise MisconfigurationException(
105
+ "Cannot use GPUStatsMonitor callback because NVIDIA driver is not installed."
106
+ )
107
+
108
+ self._log_stats = AttributeDict(
109
+ {
110
+ "memory_utilization": memory_utilization,
111
+ "gpu_utilization": gpu_utilization,
112
+ "intra_step_time": intra_step_time,
113
+ "inter_step_time": inter_step_time,
114
+ "fan_speed": fan_speed,
115
+ "temperature": temperature,
116
+ }
117
+ )
118
+
119
+ # The logical device IDs for selected devices
120
+ self._device_ids: List[int] = [] # will be assigned later in setup()
121
+
122
+ # The unmasked real GPU IDs
123
+ self._gpu_ids: List[str] = [] # will be assigned later in setup()
124
+
125
+ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
126
+ if not trainer.loggers:
127
+ raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.")
128
+
129
+ if trainer.strategy.root_device.type != "cuda":
130
+ raise MisconfigurationException(
131
+ "You are using GPUStatsMonitor but are not running on GPU."
132
+ f" The root device type is {trainer.strategy.root_device.type}."
133
+ )
134
+
135
+ # The logical device IDs for selected devices
136
+ self._device_ids = sorted(set(trainer.device_ids))
137
+
138
+ # The unmasked real GPU IDs
139
+ self._gpu_ids = self._get_gpu_ids(self._device_ids)
140
+
141
+ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
142
+ self._snap_intra_step_time: Optional[float] = None
143
+ self._snap_inter_step_time: Optional[float] = None
144
+
145
+ @rank_zero_only
146
+ def on_train_batch_start(
147
+ self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
148
+ ) -> None:
149
+ if self._log_stats.intra_step_time:
150
+ self._snap_intra_step_time = time.time()
151
+
152
+ if not trainer._logger_connector.should_update_logs:
153
+ return
154
+
155
+ gpu_stat_keys = self._get_gpu_stat_keys()
156
+ gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys])
157
+ logs = self._parse_gpu_stats(self._device_ids, gpu_stats, gpu_stat_keys)
158
+
159
+ if self._log_stats.inter_step_time and self._snap_inter_step_time:
160
+ # First log at beginning of second step
161
+ logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000
162
+
163
+ for logger in trainer.loggers:
164
+ logger.log_metrics(logs, step=trainer.fit_loop.epoch_loop._batches_that_stepped)
165
+
166
+ @rank_zero_only
167
+ def on_train_batch_end(
168
+ self,
169
+ trainer: "pl.Trainer",
170
+ pl_module: "pl.LightningModule",
171
+ outputs: STEP_OUTPUT,
172
+ batch: Any,
173
+ batch_idx: int,
174
+ ) -> None:
175
+ if self._log_stats.inter_step_time:
176
+ self._snap_inter_step_time = time.time()
177
+
178
+ if not trainer._logger_connector.should_update_logs:
179
+ return
180
+
181
+ gpu_stat_keys = self._get_gpu_stat_keys() + self._get_gpu_device_stat_keys()
182
+ gpu_stats = self._get_gpu_stats([k for k, _ in gpu_stat_keys])
183
+ logs = self._parse_gpu_stats(self._device_ids, gpu_stats, gpu_stat_keys)
184
+
185
+ if self._log_stats.intra_step_time and self._snap_intra_step_time:
186
+ logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000
187
+
188
+ for logger in trainer.loggers:
189
+ logger.log_metrics(logs, step=trainer.fit_loop.epoch_loop._batches_that_stepped)
190
+
191
+ @staticmethod
192
+ def _get_gpu_ids(device_ids: List[int]) -> List[str]:
193
+ """Get the unmasked real GPU IDs."""
194
+ # All devices if `CUDA_VISIBLE_DEVICES` unset
195
+ default = ",".join(str(i) for i in range(torch.cuda.device_count()))
196
+ cuda_visible_devices: List[str] = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",")
197
+ return [cuda_visible_devices[device_id].strip() for device_id in device_ids]
198
+
199
+ def _get_gpu_stats(self, queries: List[str]) -> List[List[float]]:
200
+ if not queries:
201
+ return []
202
+
203
+ """Run nvidia-smi to get the gpu stats"""
204
+ gpu_query = ",".join(queries)
205
+ format = "csv,nounits,noheader"
206
+ gpu_ids = ",".join(self._gpu_ids)
207
+ result = subprocess.run(
208
+ [
209
+ # it's ok to suppress the warning here since we ensure nvidia-smi exists during init
210
+ shutil.which("nvidia-smi"), # type: ignore
211
+ f"--query-gpu={gpu_query}",
212
+ f"--format={format}",
213
+ f"--id={gpu_ids}",
214
+ ],
215
+ encoding="utf-8",
216
+ capture_output=True,
217
+ check=True,
218
+ )
219
+
220
+ def _to_float(x: str) -> float:
221
+ try:
222
+ return float(x)
223
+ except ValueError:
224
+ return 0.0
225
+
226
+ stats = [[_to_float(x) for x in s.split(", ")] for s in result.stdout.strip().split(os.linesep)]
227
+ return stats
228
+
229
+ @staticmethod
230
+ def _parse_gpu_stats(
231
+ device_ids: List[int], stats: List[List[float]], keys: List[Tuple[str, str]]
232
+ ) -> Dict[str, float]:
233
+ """Parse the gpu stats into a loggable dict."""
234
+ logs = {}
235
+ for i, device_id in enumerate(device_ids):
236
+ for j, (x, unit) in enumerate(keys):
237
+ logs[f"device_id: {device_id}/{x} ({unit})"] = stats[i][j]
238
+ return logs
239
+
240
+ def _get_gpu_stat_keys(self) -> List[Tuple[str, str]]:
241
+ """Get the GPU stats keys."""
242
+ stat_keys = []
243
+
244
+ if self._log_stats.gpu_utilization:
245
+ stat_keys.append(("utilization.gpu", "%"))
246
+
247
+ if self._log_stats.memory_utilization:
248
+ stat_keys.extend([("memory.used", "MB"), ("memory.free", "MB"), ("utilization.memory", "%")])
249
+
250
+ return stat_keys
251
+
252
+ def _get_gpu_device_stat_keys(self) -> List[Tuple[str, str]]:
253
+ """Get the device stats keys."""
254
+ stat_keys = []
255
+
256
+ if self._log_stats.fan_speed:
257
+ stat_keys.append(("fan.speed", "%"))
258
+
259
+ if self._log_stats.temperature:
260
+ stat_keys.extend([("temperature.gpu", "°C"), ("temperature.memory", "°C")])
261
+
262
+ return stat_keys
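Since ``GPUStatsMonitor`` is deprecated in favour of ``DeviceStatsMonitor`` (see the deprecation warning above), a minimal migration sketch, assuming a CUDA-capable machine, is:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import DeviceStatsMonitor

    # DeviceStatsMonitor reports stats for whatever device the Trainer runs on
    # (CPU, GPU, TPU), so it replaces the nvidia-smi based callback above.
    trainer = Trainer(accelerator="gpu", devices=1, callbacks=[DeviceStatsMonitor()])

As with ``GPUStatsMonitor``, a logger must be configured on the ``Trainer`` for the stats to be recorded.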
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/lambda_function.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ r"""
15
+ Lambda Callback
16
+ ^^^^^^^^^^^^^^^
17
+
18
+ Create a simple callback on the fly using lambda functions.
19
+
20
+ """
21
+
22
+ from typing import Callable, Optional
23
+
24
+ from pytorch_lightning.callbacks.base import Callback
25
+
26
+
27
+ class LambdaCallback(Callback):
28
+ r"""
29
+ Create a simple callback on the fly using lambda functions.
30
+
31
+ Args:
32
+ **kwargs: hooks supported by :class:`~pytorch_lightning.callbacks.base.Callback`
33
+
34
+ Example::
35
+
36
+ >>> from pytorch_lightning import Trainer
37
+ >>> from pytorch_lightning.callbacks import LambdaCallback
38
+ >>> trainer = Trainer(callbacks=[LambdaCallback(setup=lambda *args: print('setup'))])
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ on_before_accelerator_backend_setup: Optional[Callable] = None,
44
+ setup: Optional[Callable] = None,
45
+ on_configure_sharded_model: Optional[Callable] = None,
46
+ teardown: Optional[Callable] = None,
47
+ on_init_start: Optional[Callable] = None,
48
+ on_init_end: Optional[Callable] = None,
49
+ on_fit_start: Optional[Callable] = None,
50
+ on_fit_end: Optional[Callable] = None,
51
+ on_sanity_check_start: Optional[Callable] = None,
52
+ on_sanity_check_end: Optional[Callable] = None,
53
+ on_train_batch_start: Optional[Callable] = None,
54
+ on_train_batch_end: Optional[Callable] = None,
55
+ on_train_epoch_start: Optional[Callable] = None,
56
+ on_train_epoch_end: Optional[Callable] = None,
57
+ on_validation_epoch_start: Optional[Callable] = None,
58
+ on_validation_epoch_end: Optional[Callable] = None,
59
+ on_test_epoch_start: Optional[Callable] = None,
60
+ on_test_epoch_end: Optional[Callable] = None,
61
+ on_epoch_start: Optional[Callable] = None,
62
+ on_epoch_end: Optional[Callable] = None,
63
+ on_batch_start: Optional[Callable] = None,
64
+ on_validation_batch_start: Optional[Callable] = None,
65
+ on_validation_batch_end: Optional[Callable] = None,
66
+ on_test_batch_start: Optional[Callable] = None,
67
+ on_test_batch_end: Optional[Callable] = None,
68
+ on_batch_end: Optional[Callable] = None,
69
+ on_train_start: Optional[Callable] = None,
70
+ on_train_end: Optional[Callable] = None,
71
+ on_pretrain_routine_start: Optional[Callable] = None,
72
+ on_pretrain_routine_end: Optional[Callable] = None,
73
+ on_validation_start: Optional[Callable] = None,
74
+ on_validation_end: Optional[Callable] = None,
75
+ on_test_start: Optional[Callable] = None,
76
+ on_test_end: Optional[Callable] = None,
77
+ on_keyboard_interrupt: Optional[Callable] = None,
78
+ on_exception: Optional[Callable] = None,
79
+ on_save_checkpoint: Optional[Callable] = None,
80
+ on_load_checkpoint: Optional[Callable] = None,
81
+ on_before_backward: Optional[Callable] = None,
82
+ on_after_backward: Optional[Callable] = None,
83
+ on_before_optimizer_step: Optional[Callable] = None,
84
+ on_before_zero_grad: Optional[Callable] = None,
85
+ on_predict_start: Optional[Callable] = None,
86
+ on_predict_end: Optional[Callable] = None,
87
+ on_predict_batch_start: Optional[Callable] = None,
88
+ on_predict_batch_end: Optional[Callable] = None,
89
+ on_predict_epoch_start: Optional[Callable] = None,
90
+ on_predict_epoch_end: Optional[Callable] = None,
91
+ ):
92
+ for k, v in locals().items():
93
+ if k == "self":
94
+ continue
95
+ if v is not None:
96
+ setattr(self, k, v)
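Because ``__init__`` simply assigns every non-``None`` keyword argument onto the instance under the hook's name, each supported hook can be provided inline. A short sketch:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LambdaCallback

    # each keyword corresponds one-to-one to a Callback hook of the same name
    quick_hooks = LambdaCallback(
        on_train_start=lambda trainer, pl_module: print("training started"),
        on_train_end=lambda trainer, pl_module: print("training finished"),
    )
    trainer = Trainer(callbacks=[quick_hooks])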
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/lr_monitor.py ADDED
@@ -0,0 +1,354 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ r"""
15
+
16
+ Learning Rate Monitor
17
+ =====================
18
+
19
+ Monitors and logs the learning rate of LR schedulers during training.
20
+
21
+ """
22
+ import itertools
23
+ from collections import defaultdict
24
+ from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Type
25
+
26
+ from torch.optim.optimizer import Optimizer
27
+
28
+ import pytorch_lightning as pl
29
+ from pytorch_lightning.callbacks.base import Callback
30
+ from pytorch_lightning.utilities.exceptions import MisconfigurationException
31
+ from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
32
+ from pytorch_lightning.utilities.types import LRSchedulerConfig
33
+
34
+
35
+ class LearningRateMonitor(Callback):
36
+ r"""
37
+ Automatically monitors and logs the learning rate of learning rate schedulers during training.
38
+
39
+ Args:
40
+ logging_interval: set to ``'epoch'`` or ``'step'`` to log ``lr`` of all optimizers
41
+ at the same interval, set to ``None`` to log at the individual interval
42
+ according to the ``interval`` key of each scheduler. Defaults to ``None``.
43
+ log_momentum: option to also log the momentum values of the optimizer, if the optimizer
44
+ has the ``momentum`` or ``betas`` attribute. Defaults to ``False``.
45
+
46
+ Raises:
47
+ MisconfigurationException:
48
+ If ``logging_interval`` is none of ``"step"``, ``"epoch"``, or ``None``.
49
+
50
+ Example::
51
+
52
+ >>> from pytorch_lightning import Trainer
53
+ >>> from pytorch_lightning.callbacks import LearningRateMonitor
54
+ >>> lr_monitor = LearningRateMonitor(logging_interval='step')
55
+ >>> trainer = Trainer(callbacks=[lr_monitor])
56
+
57
+ Logging names are automatically determined based on optimizer class name.
58
+ In case of multiple optimizers of same type, they will be named ``Adam``,
59
+ ``Adam-1`` etc. If a optimizer has multiple parameter groups they will
60
+ be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a
61
+ ``name`` keyword in the construction of the learning rate schedulers.
62
+ A ``name`` keyword can also be used for parameter groups in the
63
+ construction of the optimizer.
64
+
65
+ Example::
66
+
67
+ def configure_optimizers(self):
68
+ optimizer = torch.optim.Adam(...)
69
+ lr_scheduler = {
70
+ 'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...)
71
+ 'name': 'my_logging_name'
72
+ }
73
+ return [optimizer], [lr_scheduler]
74
+
75
+ Example::
76
+
77
+ def configure_optimizers(self):
78
+ optimizer = torch.optim.SGD(
79
+ [{
80
+ 'params': [p for p in self.parameters()],
81
+ 'name': 'my_parameter_group_name'
82
+ }],
83
+ lr=0.1
84
+ )
85
+ lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, ...)
86
+ return [optimizer], [lr_scheduler]
87
+
88
+ """
89
+
90
+ def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False) -> None:
91
+ if logging_interval not in (None, "step", "epoch"):
92
+ raise MisconfigurationException("logging_interval should be `step` or `epoch` or `None`.")
93
+
94
+ self.logging_interval = logging_interval
95
+ self.log_momentum = log_momentum
96
+ self.lrs: Dict[str, List[float]] = {}
97
+ self._lr_sch_names: List[str] = []
98
+
99
+ def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
100
+ """Called before training, determines unique names for all lr schedulers in the case of multiple of the
101
+ same type or in the case of multiple parameter groups.
102
+
103
+ Raises:
104
+ MisconfigurationException:
105
+ If ``Trainer`` has no ``logger``.
106
+ """
107
+ if not trainer.loggers:
108
+ raise MisconfigurationException(
109
+ "Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
110
+ )
111
+
112
+ if self.log_momentum:
113
+
114
+ def _check_no_key(key: str) -> bool:
115
+ if trainer.lr_scheduler_configs:
116
+ return any(
117
+ key not in config.scheduler.optimizer.defaults for config in trainer.lr_scheduler_configs
118
+ )
119
+
120
+ return any(key not in optimizer.defaults for optimizer in trainer.optimizers)
121
+
122
+ if _check_no_key("momentum") and _check_no_key("betas"):
123
+ rank_zero_warn(
124
+ "You have set log_momentum=True, but some optimizers do not"
125
+ " have momentum. This will log a value 0 for the momentum.",
126
+ category=RuntimeWarning,
127
+ )
128
+
129
+ # Find names for schedulers
130
+ names: List[List[str]] = []
131
+ (
132
+ sched_hparam_keys,
133
+ optimizers_with_scheduler,
134
+ optimizers_with_scheduler_types,
135
+ ) = self._find_names_from_schedulers(trainer.lr_scheduler_configs)
136
+ names.extend(sched_hparam_keys)
137
+
138
+ # Find names for leftover optimizers
139
+ optimizer_hparam_keys, _ = self._find_names_from_optimizers(
140
+ trainer.optimizers,
141
+ seen_optimizers=optimizers_with_scheduler,
142
+ seen_optimizer_types=optimizers_with_scheduler_types,
143
+ )
144
+ names.extend(optimizer_hparam_keys)
145
+
146
+ # Initialize for storing values
147
+ names_flatten = list(itertools.chain.from_iterable(names))
148
+ self.lrs = {name: [] for name in names_flatten}
149
+ self.last_momentum_values = {name + "-momentum": None for name in names_flatten}
150
+
151
+ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
152
+ if not trainer._logger_connector.should_update_logs:
153
+ return
154
+
155
+ if self.logging_interval != "epoch":
156
+ interval = "step" if self.logging_interval is None else "any"
157
+ latest_stat = self._extract_stats(trainer, interval)
158
+
159
+ if latest_stat:
160
+ for logger in trainer.loggers:
161
+ logger.log_metrics(latest_stat, step=trainer.fit_loop.epoch_loop._batches_that_stepped)
162
+
163
+ def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
164
+ if self.logging_interval != "step":
165
+ interval = "epoch" if self.logging_interval is None else "any"
166
+ latest_stat = self._extract_stats(trainer, interval)
167
+
168
+ if latest_stat:
169
+ for logger in trainer.loggers:
170
+ logger.log_metrics(latest_stat, step=trainer.fit_loop.epoch_loop._batches_that_stepped)
171
+
172
+ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]:
173
+ latest_stat = {}
174
+
175
+ (
176
+ scheduler_hparam_keys,
177
+ optimizers_with_scheduler,
178
+ optimizers_with_scheduler_types,
179
+ ) = self._find_names_from_schedulers(trainer.lr_scheduler_configs, add_lr_sch_names=False)
180
+ self._remap_keys(scheduler_hparam_keys)
181
+
182
+ for name, config in zip(scheduler_hparam_keys, trainer.lr_scheduler_configs):
183
+ if interval in [config.interval, "any"]:
184
+ opt = config.scheduler.optimizer
185
+ current_stat = self._get_lr_momentum_stat(opt, name)
186
+ latest_stat.update(current_stat)
187
+
188
+ optimizer_hparam_keys, optimizers_without_scheduler = self._find_names_from_optimizers(
189
+ trainer.optimizers,
190
+ seen_optimizers=optimizers_with_scheduler,
191
+ seen_optimizer_types=optimizers_with_scheduler_types,
192
+ add_lr_sch_names=False,
193
+ )
194
+ self._remap_keys(optimizer_hparam_keys)
195
+
196
+ for opt, names in zip(optimizers_without_scheduler, optimizer_hparam_keys):
197
+ current_stat = self._get_lr_momentum_stat(opt, names)
198
+ latest_stat.update(current_stat)
199
+
200
+ return latest_stat
201
+
202
+ def _get_lr_momentum_stat(self, optimizer: Optimizer, names: List[str]) -> Dict[str, float]:
203
+ lr_momentum_stat = {}
204
+ param_groups = optimizer.param_groups
205
+ use_betas = "betas" in optimizer.defaults
206
+
207
+ for pg, name in zip(param_groups, names):
208
+ lr = self._extract_lr(pg, name)
209
+ lr_momentum_stat.update(lr)
210
+ momentum = self._extract_momentum(
211
+ param_group=pg, name=name.replace(name, f"{name}-momentum"), use_betas=use_betas
212
+ )
213
+ lr_momentum_stat.update(momentum)
214
+
215
+ return lr_momentum_stat
216
+
217
+ def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]:
218
+ lr = param_group["lr"]
219
+ self.lrs[name].append(lr)
220
+ return {name: lr}
221
+
222
+ def _remap_keys(self, names: List[List[str]], token: str = "/pg1") -> None:
223
+ """This function is used the remap the keys if param groups for a given optimizer increased."""
224
+ for group_new_names in names:
225
+ for new_name in group_new_names:
226
+ old_name = new_name.replace(token, "")
227
+ if token in new_name and old_name in self.lrs:
228
+ self.lrs[new_name] = self.lrs.pop(old_name)
229
+ elif new_name not in self.lrs:
230
+ self.lrs[new_name] = []
231
+
232
+ def _extract_momentum(self, param_group: Dict[str, List], name: str, use_betas: bool) -> Dict[str, float]:
233
+ if not self.log_momentum:
234
+ return {}
235
+
236
+ momentum = param_group["betas"][0] if use_betas else param_group.get("momentum", 0)
237
+ self.last_momentum_values[name] = momentum
238
+ return {name: momentum}
239
+
240
+ def _add_prefix(
241
+ self, name: str, optimizer_cls: Type[Optimizer], seen_optimizer_types: DefaultDict[Type[Optimizer], int]
242
+ ) -> str:
243
+ if optimizer_cls not in seen_optimizer_types:
244
+ return name
245
+ count = seen_optimizer_types[optimizer_cls]
246
+ return name + f"-{count - 1}" if count > 1 else name
247
+
248
+ def _add_suffix(self, name: str, param_groups: List[Dict], param_group_index: int, use_names: bool = True) -> str:
249
+ if len(param_groups) > 1:
250
+ if not use_names:
251
+ return f"{name}/pg{param_group_index+1}"
252
+ pg_name = param_groups[param_group_index].get("name", f"pg{param_group_index+1}")
253
+ return f"{name}/{pg_name}"
254
+ elif use_names:
255
+ pg_name = param_groups[param_group_index].get("name")
256
+ return f"{name}/{pg_name}" if pg_name else name
257
+ return name
258
+
259
+ def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]:
260
+ names = [pg.get("name", f"pg{i}") for i, pg in enumerate(param_groups, start=1)]
261
+ unique = set(names)
262
+ if len(names) == len(unique):
263
+ return set()
264
+ return {n for n in names if names.count(n) > 1}
265
+
266
+ def _find_names_from_schedulers(
267
+ self, lr_scheduler_configs: List[LRSchedulerConfig], add_lr_sch_names: bool = True
268
+ ) -> Tuple[List[List[str]], List[Optimizer], DefaultDict[Type[Optimizer], int]]:
269
+ # Create unique names in the case we have multiple of the same learning
270
+ # rate scheduler + multiple parameter groups
271
+ names = []
272
+ seen_optimizers: List[Optimizer] = []
273
+ seen_optimizer_types: DefaultDict[Type[Optimizer], int] = defaultdict(int)
274
+ for config in lr_scheduler_configs:
275
+ sch = config.scheduler
276
+ if config.name is not None:
277
+ name = config.name
278
+ else:
279
+ name = "lr-" + sch.optimizer.__class__.__name__
280
+
281
+ updated_names = self._check_duplicates_and_update_name(
282
+ sch.optimizer, name, seen_optimizers, seen_optimizer_types, config, add_lr_sch_names
283
+ )
284
+ names.append(updated_names)
285
+
286
+ return names, seen_optimizers, seen_optimizer_types
287
+
288
+ def _find_names_from_optimizers(
289
+ self,
290
+ optimizers: List[Any],
291
+ seen_optimizers: List[Optimizer],
292
+ seen_optimizer_types: DefaultDict[Type[Optimizer], int],
293
+ add_lr_sch_names: bool = True,
294
+ ) -> Tuple[List[List[str]], List[Optimizer]]:
295
+ names = []
296
+ optimizers_without_scheduler = []
297
+
298
+ for optimizer in optimizers:
299
+ # Deepspeed optimizer wraps the native optimizer
300
+ optimizer = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
301
+ if optimizer in seen_optimizers:
302
+ continue
303
+
304
+ name = "lr-" + optimizer.__class__.__name__
305
+ updated_names = self._check_duplicates_and_update_name(
306
+ optimizer, name, seen_optimizers, seen_optimizer_types, None, add_lr_sch_names
307
+ )
308
+ names.append(updated_names)
309
+ optimizers_without_scheduler.append(optimizer)
310
+
311
+ return names, optimizers_without_scheduler
312
+
313
+ def _check_duplicates_and_update_name(
314
+ self,
315
+ optimizer: Optimizer,
316
+ name: str,
317
+ seen_optimizers: List[Optimizer],
318
+ seen_optimizer_types: DefaultDict[Type[Optimizer], int],
319
+ lr_scheduler_config: Optional[LRSchedulerConfig],
320
+ add_lr_sch_names: bool = True,
321
+ ) -> List[str]:
322
+ seen_optimizers.append(optimizer)
323
+ optimizer_cls = type(optimizer)
324
+ if lr_scheduler_config is not None and lr_scheduler_config.name is None:
325
+ seen_optimizer_types[optimizer_cls] += 1
326
+ elif lr_scheduler_config is None:
327
+ seen_optimizer_types[optimizer_cls] += 1
328
+
329
+ # Multiple param groups for the same optimizer
330
+ param_groups = optimizer.param_groups
331
+ duplicates = self._duplicate_param_group_names(param_groups)
332
+ if duplicates:
333
+ raise MisconfigurationException(
334
+ "A single `Optimizer` cannot have multiple parameter groups with identical "
335
+ f"`name` values. {name} has duplicated parameter group names {duplicates}"
336
+ )
337
+
338
+ name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)
339
+ name_list = [self._add_suffix(name, param_groups, i) for i in range(len(param_groups))]
340
+
341
+ if add_lr_sch_names:
342
+ self._lr_sch_names.append(name)
343
+
344
+ return name_list
345
+
346
+ @property
347
+ def lr_sch_names(self) -> List[str]:
348
+ # TODO remove `lr_sch_names` and `add_lr_sch_names` argument in v1.7.0
349
+ rank_zero_deprecation(
350
+ "`LearningRateMonitor.lr_sch_names` has been deprecated in v1.5 and will be removed in 1.7."
351
+ " Consider accessing them using `LearningRateMonitor.lrs.keys()` which will return"
352
+ " the names of all the optimizers, even those without a scheduler."
353
+ )
354
+ return self._lr_sch_names
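For reference, a minimal usage sketch tying the naming scheme described above to what actually gets logged (assuming a single ``Adam`` optimizer with two unnamed parameter groups):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LearningRateMonitor

    # with one Adam optimizer and two unnamed parameter groups, the learning rates
    # are logged under "lr-Adam/pg1" and "lr-Adam/pg2"; with log_momentum=True the
    # matching "-momentum" suffixed keys are logged as well
    lr_monitor = LearningRateMonitor(logging_interval="step", log_momentum=True)
    trainer = Trainer(callbacks=[lr_monitor])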
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py ADDED
@@ -0,0 +1,720 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Model Checkpointing
16
+ ===================
17
+
18
+ Automatically save model checkpoints during training.
19
+
20
+ """
21
+ import logging
22
+ import os
23
+ import re
24
+ import time
25
+ import warnings
26
+ from copy import deepcopy
27
+ from datetime import timedelta
28
+ from typing import Any, Dict, Optional
29
+ from weakref import proxy
30
+
31
+ import numpy as np
32
+ import torch
33
+ import yaml
34
+
35
+ import pytorch_lightning as pl
36
+ from pytorch_lightning.callbacks.base import Callback
37
+ from pytorch_lightning.utilities.cloud_io import get_filesystem
38
+ from pytorch_lightning.utilities.exceptions import MisconfigurationException
39
+ from pytorch_lightning.utilities.logger import _name, _version
40
+ from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
41
+ from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
42
+ from pytorch_lightning.utilities.warnings import WarningCache
43
+
44
+ log = logging.getLogger(__name__)
45
+ warning_cache = WarningCache()
46
+
47
+
48
+ class ModelCheckpoint(Callback):
49
+ r"""
50
+ Save the model periodically by monitoring a quantity. Every metric logged with
51
+ :meth:`~pytorch_lightning.core.lightning.log` or :meth:`~pytorch_lightning.core.lightning.log_dict` in
52
+ LightningModule is a candidate for the monitor key. For more information, see
53
+ :ref:`checkpointing`.
54
+
55
+ After training finishes, use :attr:`best_model_path` to retrieve the path to the
56
+ best checkpoint file and :attr:`best_model_score` to retrieve its score.
57
+
58
+ Args:
59
+ dirpath: directory to save the model file.
60
+
61
+ Example::
62
+
63
+ # custom path
64
+ # saves a file like: my/path/epoch=0-step=10.ckpt
65
+ >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
66
+
67
+ By default, dirpath is ``None`` and will be set at runtime to the location
68
+ specified by :class:`~pytorch_lightning.trainer.trainer.Trainer`'s
69
+ :paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir` or
70
+ :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_save_path` arguments,
71
+ and if the Trainer uses a logger, the path will also contain logger name and version.
72
+
73
+ filename: checkpoint filename. Can contain named formatting options to be auto-filled.
74
+
75
+ Example::
76
+
77
+ # save any arbitrary metrics like `val_loss`, etc. in name
78
+ # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
79
+ >>> checkpoint_callback = ModelCheckpoint(
80
+ ... dirpath='my/path',
81
+ ... filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
82
+ ... )
83
+
84
+ By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``.
85
+ monitor: quantity to monitor. By default it is ``None``, which saves a checkpoint only for the last epoch.
86
+ verbose: verbosity mode. Default: ``False``.
87
+ save_last: When ``True``, saves an exact copy of the checkpoint to a file `last.ckpt` whenever a checkpoint
88
+ file gets saved. This allows accessing the latest checkpoint in a deterministic manner. Default: ``None``.
89
+ save_top_k: if ``save_top_k == k``,
90
+ the best k models according to
91
+ the quantity monitored will be saved.
92
+ if ``save_top_k == 0``, no models are saved.
93
+ if ``save_top_k == -1``, all models are saved.
94
+ Please note that the monitors are checked every ``every_n_epochs`` epochs.
95
+ if ``save_top_k >= 2`` and the callback is called multiple
96
+ times inside an epoch, the name of the saved file will be
97
+ appended with a version count starting with ``v1``.
98
+ mode: one of {min, max}.
99
+ If ``save_top_k != 0``, the decision to overwrite the current save file is made
100
+ based on either the maximization or the minimization of the monitored quantity.
101
+ For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc.
102
+ auto_insert_metric_name: When ``True``, the checkpoints filenames will contain the metric name.
103
+ For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}'`` with epoch ``1`` and acc ``1.12`` will resolve
104
+ to ``checkpoint_epoch=01-acc=01.ckpt``. It is useful to set it to ``False`` when metric names contain ``/``
105
+ as this will result in extra folders.
106
+ save_weights_only: if ``True``, then only the model's weights will be
107
+ saved. Otherwise, the optimizer states, lr-scheduler states, etc. are added to the checkpoint too.
108
+ every_n_train_steps: Number of training steps between checkpoints.
109
+ If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
110
+ To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
111
+ This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
112
+ train_time_interval: Checkpoints are monitored at the specified time interval.
113
+ For all practical purposes, this cannot be smaller than the amount
114
+ of time it takes to process a single training batch. This is not
115
+ guaranteed to execute at the exact time specified, but should be close.
116
+ This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
117
+ every_n_epochs: Number of epochs between checkpoints.
118
+ This value must be ``None`` or non-negative.
119
+ To disable saving top-k checkpoints, set ``every_n_epochs = 0``.
120
+ This argument does not impact the saving of ``save_last=True`` checkpoints.
121
+ If all of ``every_n_epochs``, ``every_n_train_steps`` and
122
+ ``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch
123
+ (equivalent to ``every_n_epochs = 1``).
124
+ If ``every_n_epochs == None`` and either ``every_n_train_steps != None`` or ``train_time_interval != None``,
125
+ saving at the end of each epoch is disabled
126
+ (equivalent to ``every_n_epochs = 0``).
127
+ This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``.
128
+ Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and
129
+ ``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
130
+ will only save checkpoints at epochs 0 < E <= N
131
+ where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
132
+ save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch.
133
+ If this is ``False``, then the check runs at the end of the validation.
134
+
135
+ Note:
136
+ For extra customization, ModelCheckpoint includes the following attributes:
137
+
138
+ - ``CHECKPOINT_JOIN_CHAR = "-"``
139
+ - ``CHECKPOINT_NAME_LAST = "last"``
140
+ - ``FILE_EXTENSION = ".ckpt"``
141
+ - ``STARTING_VERSION = 1``
142
+
143
+ For example, you can change the default last checkpoint name by doing
144
+ ``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``
145
+
146
+ If you want to checkpoint every N hours, every M train batches, and/or every K val epochs,
147
+ then you should create multiple ``ModelCheckpoint`` callbacks.
148
+
149
+ If the checkpoint's ``dirpath`` changed from what it was before while resuming the training,
150
+ only ``best_model_path`` will be reloaded and a warning will be issued.
151
+
152
+ Raises:
153
+ MisconfigurationException:
154
+ If ``save_top_k`` is smaller than ``-1``,
155
+ if ``monitor`` is ``None`` and ``save_top_k`` is none of ``None``, ``-1``, and ``0``, or
156
+ if ``mode`` is none of ``"min"`` or ``"max"``.
157
+ ValueError:
158
+ If ``trainer.save_checkpoint`` is ``None``.
159
+
160
+ Example::
161
+
162
+ >>> from pytorch_lightning import Trainer
163
+ >>> from pytorch_lightning.callbacks import ModelCheckpoint
164
+
165
+ # saves checkpoints to 'my/path/' at every epoch
166
+ >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
167
+ >>> trainer = Trainer(callbacks=[checkpoint_callback])
168
+
169
+ # save epoch and val_loss in name
170
+ # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
171
+ >>> checkpoint_callback = ModelCheckpoint(
172
+ ... monitor='val_loss',
173
+ ... dirpath='my/path/',
174
+ ... filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
175
+ ... )
176
+
177
+ # save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
178
+ # or Neptune, due to the presence of characters like '=' or '/')
179
+ # saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
180
+ >>> checkpoint_callback = ModelCheckpoint(
181
+ ... monitor='val/loss',
182
+ ... dirpath='my/path/',
183
+ ... filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
184
+ ... auto_insert_metric_name=False
185
+ ... )
186
+
187
+ # retrieve the best checkpoint after training
188
+ checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
189
+ trainer = Trainer(callbacks=[checkpoint_callback])
190
+ model = ...
191
+ trainer.fit(model)
192
+ checkpoint_callback.best_model_path
193
+
194
+ .. tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the
195
+ following arguments:
196
+
197
+ *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval, save_on_train_epoch_end*
198
+
199
+ Read more: :ref:`Persisting Callback State`
200
+ """
201
+
202
+ CHECKPOINT_JOIN_CHAR = "-"
203
+ CHECKPOINT_NAME_LAST = "last"
204
+ FILE_EXTENSION = ".ckpt"
205
+ STARTING_VERSION = 1
206
+
207
+ def __init__(
208
+ self,
209
+ dirpath: Optional[_PATH] = None,
210
+ filename: Optional[str] = None,
211
+ monitor: Optional[str] = None,
212
+ verbose: bool = False,
213
+ save_last: Optional[bool] = None,
214
+ save_top_k: int = 1,
215
+ save_weights_only: bool = False,
216
+ mode: str = "min",
217
+ auto_insert_metric_name: bool = True,
218
+ every_n_train_steps: Optional[int] = None,
219
+ train_time_interval: Optional[timedelta] = None,
220
+ every_n_epochs: Optional[int] = None,
221
+ save_on_train_epoch_end: Optional[bool] = None,
222
+ ):
223
+ super().__init__()
224
+ self.monitor = monitor
225
+ self.verbose = verbose
226
+ self.save_last = save_last
227
+ self.save_top_k = save_top_k
228
+ self.save_weights_only = save_weights_only
229
+ self.auto_insert_metric_name = auto_insert_metric_name
230
+ self._save_on_train_epoch_end = save_on_train_epoch_end
231
+ self._last_global_step_saved = 0 # no need to save when no steps were taken
232
+ self._last_time_checked: Optional[float] = None
233
+ self.current_score = None
234
+ self.best_k_models = {}
235
+ self.kth_best_model_path = ""
236
+ self.best_model_score = None
237
+ self.best_model_path = ""
238
+ self.last_model_path = ""
239
+
240
+ self.__init_monitor_mode(mode)
241
+ self.__init_ckpt_dir(dirpath, filename)
242
+ self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)
243
+ self.__validate_init_configuration()
244
+
245
+ @property
246
+ def state_key(self) -> str:
247
+ return self._generate_state_key(
248
+ monitor=self.monitor,
249
+ mode=self.mode,
250
+ every_n_train_steps=self._every_n_train_steps,
251
+ every_n_epochs=self._every_n_epochs,
252
+ train_time_interval=self._train_time_interval,
253
+ save_on_train_epoch_end=self._save_on_train_epoch_end,
254
+ )
255
+
256
+ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
257
+ self.__resolve_ckpt_dir(trainer)
258
+ if trainer.is_global_zero and stage == "fit":
259
+ self.__warn_if_dir_not_empty(self.dirpath)
260
+
261
+ # NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states,
262
+ # because the attributes are part of the state_key which needs to be fully defined before reloading.
263
+ if self._save_on_train_epoch_end is None:
264
+ # if the user runs validation multiple times per training epoch or multiple training epochs without
265
+ # validation, then we run after validation instead of on train epoch end
266
+ self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
267
+
268
+ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
269
+ self._last_time_checked = time.monotonic()
270
+
271
+ def on_train_batch_end(
272
+ self,
273
+ trainer: "pl.Trainer",
274
+ pl_module: "pl.LightningModule",
275
+ outputs: STEP_OUTPUT,
276
+ batch: Any,
277
+ batch_idx: int,
278
+ ) -> None:
279
+ """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
280
+ if self._should_skip_saving_checkpoint(trainer):
281
+ return
282
+ skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)
283
+
284
+ train_time_interval = self._train_time_interval
285
+ skip_time = True
286
+ now = time.monotonic()
287
+ if train_time_interval:
288
+ prev_time_check = self._last_time_checked
289
+ skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
290
+ # in case we have time differences across ranks
291
+ # broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
292
+ skip_time = trainer.strategy.broadcast(skip_time)
293
+
294
+ if skip_batch and skip_time:
295
+ return
296
+ if not skip_time:
297
+ self._last_time_checked = now
298
+
299
+ monitor_candidates = self._monitor_candidates(trainer)
300
+ self._save_topk_checkpoint(trainer, monitor_candidates)
301
+ self._save_last_checkpoint(trainer, monitor_candidates)
302
+
303
+ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
304
+ """Save a checkpoint at the end of the training epoch."""
305
+ if not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end:
306
+ monitor_candidates = self._monitor_candidates(trainer)
307
+ if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
308
+ self._save_topk_checkpoint(trainer, monitor_candidates)
309
+ self._save_last_checkpoint(trainer, monitor_candidates)
310
+
311
+ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
312
+ """Save a checkpoint at the end of the validation stage."""
313
+ if not self._should_skip_saving_checkpoint(trainer) and not self._save_on_train_epoch_end:
314
+ monitor_candidates = self._monitor_candidates(trainer)
315
+ if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
316
+ self._save_topk_checkpoint(trainer, monitor_candidates)
317
+ self._save_last_checkpoint(trainer, monitor_candidates)
318
+
319
+ def state_dict(self) -> Dict[str, Any]:
320
+ return {
321
+ "monitor": self.monitor,
322
+ "best_model_score": self.best_model_score,
323
+ "best_model_path": self.best_model_path,
324
+ "current_score": self.current_score,
325
+ "dirpath": self.dirpath,
326
+ "best_k_models": self.best_k_models,
327
+ "kth_best_model_path": self.kth_best_model_path,
328
+ "kth_value": self.kth_value,
329
+ "last_model_path": self.last_model_path,
330
+ }
331
+
332
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
333
+ dirpath_from_ckpt = state_dict.get("dirpath", self.dirpath)
334
+
335
+ if self.dirpath == dirpath_from_ckpt:
336
+ self.best_model_score = state_dict["best_model_score"]
337
+ self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path)
338
+ self.kth_value = state_dict.get("kth_value", self.kth_value)
339
+ self.best_k_models = state_dict.get("best_k_models", self.best_k_models)
340
+ self.last_model_path = state_dict.get("last_model_path", self.last_model_path)
341
+ else:
342
+ warnings.warn(
343
+ f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r},"
344
+ " therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and"
345
+ " `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded."
346
+ )
347
+
348
+ self.best_model_path = state_dict["best_model_path"]
349
+
350
+ def save_checkpoint(self, trainer: "pl.Trainer") -> None: # pragma: no-cover
351
+ """Performs the main logic around saving a checkpoint.
352
+
353
+ This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
354
+ behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
355
+ """
356
+ rank_zero_deprecation(
357
+ f"`{self.__class__.__name__}.save_checkpoint()` was deprecated in v1.6 and will be removed in v1.8."
358
+ " Instead, you can use `trainer.save_checkpoint()` to manually save a checkpoint."
359
+ )
360
+ monitor_candidates = self._monitor_candidates(trainer)
361
+ self._save_topk_checkpoint(trainer, monitor_candidates)
362
+ self._save_last_checkpoint(trainer, monitor_candidates)
363
+
364
+ def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
365
+ if self.save_top_k == 0:
366
+ return
367
+
368
+ # validate metric
369
+ if self.monitor is not None:
370
+ if self.monitor not in monitor_candidates:
371
+ m = (
372
+ f"`ModelCheckpoint(monitor={self.monitor!r})` could not find the monitored key in the returned"
373
+ f" metrics: {list(monitor_candidates)}."
374
+ f" HINT: Did you call `log({self.monitor!r}, value)` in the `LightningModule`?"
375
+ )
376
+ if trainer.fit_loop.epoch_loop.val_loop._has_run:
377
+ raise MisconfigurationException(m)
378
+ warning_cache.warn(m)
379
+ self._save_monitor_checkpoint(trainer, monitor_candidates)
380
+ else:
381
+ self._save_none_monitor_checkpoint(trainer, monitor_candidates)
382
+
383
+ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
384
+ trainer.save_checkpoint(filepath, self.save_weights_only)
385
+
386
+ self._last_global_step_saved = trainer.global_step
387
+
388
+ # notify loggers
389
+ if trainer.is_global_zero:
390
+ for logger in trainer.loggers:
391
+ logger.after_save_checkpoint(proxy(self))
392
+
393
+ def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
394
+ from pytorch_lightning.trainer.states import TrainerFn
395
+
396
+ return (
397
+ trainer.fast_dev_run # disable checkpointing with fast_dev_run
398
+ or trainer.state.fn != TrainerFn.FITTING # don't save anything during non-fit
399
+ or trainer.sanity_checking # don't save anything during sanity check
400
+ or self._last_global_step_saved == trainer.global_step # already saved at the last step
401
+ )
402
+
403
+ def __validate_init_configuration(self) -> None:
404
+ if self.save_top_k < -1:
405
+ raise MisconfigurationException(f"Invalid value for save_top_k={self.save_top_k}. Must be >= -1")
406
+ if self._every_n_train_steps < 0:
407
+ raise MisconfigurationException(
408
+ f"Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0"
409
+ )
410
+ if self._every_n_epochs < 0:
411
+ raise MisconfigurationException(f"Invalid value for every_n_epochs={self._every_n_epochs}. Must be >= 0")
412
+
413
+ every_n_train_steps_triggered = self._every_n_train_steps >= 1
414
+ every_n_epochs_triggered = self._every_n_epochs >= 1
415
+ train_time_interval_triggered = self._train_time_interval is not None
416
+ if every_n_train_steps_triggered + every_n_epochs_triggered + train_time_interval_triggered > 1:
417
+ raise MisconfigurationException(
418
+ f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, "
419
+ f"every_n_epochs={self._every_n_epochs} and train_time_interval={self._train_time_interval} "
420
+ "should be mutually exclusive."
421
+ )
422
+
423
+ if self.monitor is None:
424
+ # -1: save all epochs, 0: nothing is saved, 1: save last epoch
425
+ if self.save_top_k not in (-1, 0, 1):
426
+ raise MisconfigurationException(
427
+ f"ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid"
428
+ " configuration. No quantity for top_k to track."
429
+ )
430
+
431
+ if self.save_top_k == -1 and self.save_last:
432
+ rank_zero_info(
433
+ "ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)"
434
+ " will duplicate the last checkpoint saved."
435
+ )
436
+
437
+ def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None:
438
+ self._fs = get_filesystem(dirpath if dirpath else "")
439
+
440
+ if dirpath and self._fs.protocol == "file":
441
+ dirpath = os.path.realpath(dirpath)
442
+
443
+ self.dirpath = dirpath
444
+ self.filename = filename
445
+
446
+ def __init_monitor_mode(self, mode: str) -> None:
447
+ torch_inf = torch.tensor(np.Inf)
448
+ mode_dict = {"min": (torch_inf, "min"), "max": (-torch_inf, "max")}
449
+
450
+ if mode not in mode_dict:
451
+ raise MisconfigurationException(f"`mode` can be {', '.join(mode_dict.keys())} but got {mode}")
452
+
453
+ self.kth_value, self.mode = mode_dict[mode]
454
+
455
+ def __init_triggers(
456
+ self,
457
+ every_n_train_steps: Optional[int],
458
+ every_n_epochs: Optional[int],
459
+ train_time_interval: Optional[timedelta],
460
+ ) -> None:
461
+
462
+ # Default to running once after each validation epoch if neither
463
+ # every_n_train_steps nor every_n_epochs is set
464
+ if every_n_train_steps is None and every_n_epochs is None and train_time_interval is None:
465
+ every_n_epochs = 1
466
+ every_n_train_steps = 0
467
+ log.debug("Both every_n_train_steps and every_n_epochs are not set. Setting every_n_epochs=1")
468
+ else:
469
+ every_n_epochs = every_n_epochs or 0
470
+ every_n_train_steps = every_n_train_steps or 0
471
+
472
+ self._train_time_interval: Optional[timedelta] = train_time_interval
473
+ self._every_n_epochs: int = every_n_epochs
474
+ self._every_n_train_steps: int = every_n_train_steps
475
+
476
+ @property
477
+ def every_n_epochs(self) -> Optional[int]:
478
+ return self._every_n_epochs
479
+
480
+ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[torch.Tensor] = None) -> bool:
481
+ if current is None:
482
+ return False
483
+
484
+ if self.save_top_k == -1:
485
+ return True
486
+
487
+ less_than_k_models = len(self.best_k_models) < self.save_top_k
488
+ if less_than_k_models:
489
+ return True
490
+
491
+ monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
492
+ should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
493
+
494
+ # If using multiple devices, make sure all processes are unanimous on the decision.
495
+ should_update_best_and_save = trainer.strategy.reduce_boolean_decision(should_update_best_and_save)
496
+
497
+ return should_update_best_and_save
498
+
499
+ @classmethod
500
+ def _format_checkpoint_name(
501
+ cls,
502
+ filename: Optional[str],
503
+ metrics: Dict[str, _METRIC],
504
+ prefix: str = "",
505
+ auto_insert_metric_name: bool = True,
506
+ ) -> str:
507
+ if not filename:
508
+ # filename is not set, use default name
509
+ filename = "{epoch}" + cls.CHECKPOINT_JOIN_CHAR + "{step}"
510
+
511
+ # check and parse user passed keys in the string
512
+ groups = re.findall(r"(\{.*?)[:\}]", filename)
513
+ if len(groups) >= 0:
514
+ for group in groups:
515
+ name = group[1:]
516
+
517
+ if auto_insert_metric_name:
518
+ filename = filename.replace(group, name + "={" + name)
519
+
520
+ # support for dots: https://stackoverflow.com/a/7934969
521
+ filename = filename.replace(group, f"{{0[{name}]")
522
+
523
+ if name not in metrics:
524
+ metrics[name] = 0
525
+ filename = filename.format(metrics)
526
+
527
+ if prefix:
528
+ filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename])
529
+
530
+ return filename
531
+
532
+ def format_checkpoint_name(
533
+ self, metrics: Dict[str, _METRIC], filename: Optional[str] = None, ver: Optional[int] = None
534
+ ) -> str:
535
+ """Generate a filename according to the defined template.
536
+
537
+ Example::
538
+
539
+ >>> tmpdir = os.path.dirname(__file__)
540
+ >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
541
+ >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=0)))
542
+ 'epoch=0.ckpt'
543
+ >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
544
+ >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=5)))
545
+ 'epoch=005.ckpt'
546
+ >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
547
+ >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
548
+ 'epoch=2-val_loss=0.12.ckpt'
549
+ >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.12), filename='{epoch:d}'))
550
+ 'epoch=2.ckpt'
551
+ >>> ckpt = ModelCheckpoint(dirpath=tmpdir,
552
+ ... filename='epoch={epoch}-validation_loss={val_loss:.2f}',
553
+ ... auto_insert_metric_name=False)
554
+ >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
555
+ 'epoch=2-validation_loss=0.12.ckpt'
556
+ >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
557
+ >>> os.path.basename(ckpt.format_checkpoint_name({}))
558
+ 'missing=0.ckpt'
559
+ >>> ckpt = ModelCheckpoint(filename='{step}')
560
+ >>> os.path.basename(ckpt.format_checkpoint_name(dict(step=0)))
561
+ 'step=0.ckpt'
562
+ """
563
+ filename = filename or self.filename
564
+ filename = self._format_checkpoint_name(filename, metrics, auto_insert_metric_name=self.auto_insert_metric_name)
565
+
566
+ if ver is not None:
567
+ filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
568
+
569
+ ckpt_name = f"{filename}{self.FILE_EXTENSION}"
570
+ return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name
571
+
572
+ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
573
+ """Determines model checkpoint save directory at runtime. References attributes from the trainer's logger
574
+ to determine where to save checkpoints. The base path for saving weights is set in this priority:
575
+
576
+ 1. Checkpoint callback's path (if passed in)
577
+ 2. The default_root_dir from trainer if trainer has no logger
578
+ 3. The weights_save_path from trainer, if user provides it (deprecated)
579
+ 4. User-provided weights_save_path
580
+
581
+ The base path gets extended with logger name and version (if these are available)
582
+ and subfolder "checkpoints".
583
+ """
584
+ if self.dirpath is not None:
585
+ return # short circuit
586
+
587
+ # TODO: Remove weights_save_path logic here in v1.8
588
+ if trainer.loggers:
589
+ if trainer._weights_save_path_internal != trainer.default_root_dir:
590
+ # the user has changed weights_save_path, it overrides anything
591
+ save_dir = trainer._weights_save_path_internal
592
+ elif len(trainer.loggers) == 1:
593
+ save_dir = trainer.logger.save_dir or trainer.default_root_dir
594
+ else:
595
+ save_dir = trainer.default_root_dir
596
+
597
+ name = _name(trainer.loggers)
598
+ version = _version(trainer.loggers)
599
+ version = version if isinstance(version, str) else f"version_{version}"
600
+
601
+ ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
602
+ else:
603
+ ckpt_path = os.path.join(trainer._weights_save_path_internal, "checkpoints")
604
+
605
+ ckpt_path = trainer.strategy.broadcast(ckpt_path)
606
+
607
+ self.dirpath = ckpt_path
608
+
609
+ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
610
+ if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
611
+ rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
612
+
613
+ def _get_metric_interpolated_filepath_name(
614
+ self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None
615
+ ) -> str:
616
+ filepath = self.format_checkpoint_name(monitor_candidates)
617
+
618
+ version_cnt = self.STARTING_VERSION
619
+ while self.file_exists(filepath, trainer) and filepath != del_filepath:
620
+ filepath = self.format_checkpoint_name(monitor_candidates, ver=version_cnt)
621
+ version_cnt += 1
622
+
623
+ return filepath
624
+
625
+ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
626
+ monitor_candidates = deepcopy(trainer.callback_metrics)
627
+ # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor
628
+ # or does not exist we overwrite it as it's likely an error
629
+ epoch = monitor_candidates.get("epoch")
630
+ monitor_candidates["epoch"] = (
631
+ epoch.int() if isinstance(epoch, torch.Tensor) else torch.tensor(trainer.current_epoch)
632
+ )
633
+ step = monitor_candidates.get("step")
634
+ monitor_candidates["step"] = step.int() if isinstance(step, torch.Tensor) else torch.tensor(trainer.global_step)
635
+ return monitor_candidates
636
+
637
+ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
638
+ if not self.save_last:
639
+ return
640
+
641
+ filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)
642
+ # set the last model path before saving because it will be part of the state.
643
+ previous, self.last_model_path = self.last_model_path, filepath
644
+ self._save_checkpoint(trainer, filepath)
645
+ if previous and previous != filepath:
646
+ trainer.strategy.remove_checkpoint(previous)
647
+
648
+ def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
649
+ current = monitor_candidates.get(self.monitor)
650
+ if self.check_monitor_top_k(trainer, current):
651
+ self._update_best_and_save(current, trainer, monitor_candidates)
652
+ elif self.verbose:
653
+ epoch = monitor_candidates["epoch"]
654
+ step = monitor_candidates["step"]
655
+ rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")
656
+
657
+ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
658
+ filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
659
+ # set the best model path before saving because it will be part of the state.
660
+ previous, self.best_model_path = self.best_model_path, filepath
661
+ self._save_checkpoint(trainer, filepath)
662
+ if self.save_top_k == 1 and previous and previous != filepath:
663
+ trainer.strategy.remove_checkpoint(previous)
664
+
665
+ def _update_best_and_save(
666
+ self, current: torch.Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]
667
+ ) -> None:
668
+ k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k
669
+
670
+ del_filepath = None
671
+ if len(self.best_k_models) == k and k > 0:
672
+ del_filepath = self.kth_best_model_path
673
+ self.best_k_models.pop(del_filepath)
674
+
675
+ # do not save nan, replace with +/- inf
676
+ if isinstance(current, torch.Tensor) and torch.isnan(current):
677
+ current = torch.tensor(float("inf" if self.mode == "min" else "-inf"), device=current.device)
678
+
679
+ filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, del_filepath)
680
+
681
+ # save the current score
682
+ self.current_score = current
683
+ self.best_k_models[filepath] = current
684
+
685
+ if len(self.best_k_models) == k:
686
+ # monitor dict has reached k elements
687
+ _op = max if self.mode == "min" else min
688
+ self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
689
+ self.kth_value = self.best_k_models[self.kth_best_model_path]
690
+
691
+ _op = min if self.mode == "min" else max
692
+ self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
693
+ self.best_model_score = self.best_k_models[self.best_model_path]
694
+
695
+ if self.verbose:
696
+ epoch = monitor_candidates["epoch"]
697
+ step = monitor_candidates["step"]
698
+ rank_zero_info(
699
+ f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} reached {current:0.5f}"
700
+ f" (best {self.best_model_score:0.5f}), saving model to {filepath!r} as top {k}"
701
+ )
702
+ self._save_checkpoint(trainer, filepath)
703
+
704
+ if del_filepath is not None and filepath != del_filepath:
705
+ trainer.strategy.remove_checkpoint(del_filepath)
706
+
707
+ def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
708
+ """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
709
+ file."""
710
+ best_k = {k: v.item() for k, v in self.best_k_models.items()}
711
+ if filepath is None:
712
+ filepath = os.path.join(self.dirpath, "best_k_models.yaml")
713
+ with self._fs.open(filepath, "w") as fp:
714
+ yaml.dump(best_k, fp)
715
+
716
+ def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool:
717
+ """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal
718
+ state from diverging between ranks."""
719
+ exists = self._fs.exists(filepath)
720
+ return trainer.strategy.broadcast(exists)
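The methods above implement the checkpoint bookkeeping: `check_monitor_top_k` decides whether a new score enters `best_k_models`, `_update_best_and_save` maintains `best_model_path` and `kth_best_model_path`, and `format_checkpoint_name` interpolates logged metrics into the filename. A minimal usage sketch follows; the metric name, directory, and filename template are illustrative assumptions, not taken from this file.

# Sketch: keep the three best "val_loss" checkpoints plus a copy of the last one.
# "val_loss" must be logged from the LightningModule via self.log("val_loss", ...),
# otherwise the monitor-candidate check above warns or raises.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/",             # assumed output directory
    filename="{epoch}-{val_loss:.2f}",  # expanded by format_checkpoint_name
    monitor="val_loss",
    mode="min",
    save_top_k=3,                       # tracked in best_k_models
    save_last=True,                     # also maintains last_model_path
)
trainer = pl.Trainer(max_epochs=10, callbacks=[checkpoint_cb])
# After trainer.fit(model, datamodule=dm), checkpoint_cb.best_model_path and
# checkpoint_cb.best_model_score point at the best checkpoint.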
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_summary.py ADDED
@@ -0,0 +1,73 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Model Summary
16
+ =============
17
+
18
+ Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.
19
+
20
+ The string representation of this summary prints a table with columns containing
21
+ the name, type and number of parameters for each layer.
22
+
23
+ """
24
+ import logging
25
+ from typing import List, Tuple
26
+
27
+ import pytorch_lightning as pl
28
+ from pytorch_lightning.callbacks.base import Callback
29
+ from pytorch_lightning.utilities.model_summary import _format_summary_table, summarize
30
+
31
+ log = logging.getLogger(__name__)
32
+
33
+
34
+ class ModelSummary(Callback):
35
+ r"""
36
+ Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.
37
+
38
+ Args:
39
+ max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the
40
+ layer summary off.
41
+
42
+ Example::
43
+
44
+ >>> from pytorch_lightning import Trainer
45
+ >>> from pytorch_lightning.callbacks import ModelSummary
46
+ >>> trainer = Trainer(callbacks=[ModelSummary(max_depth=1)])
47
+ """
48
+
49
+ def __init__(self, max_depth: int = 1) -> None:
50
+ self._max_depth: int = max_depth
51
+
52
+ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
53
+ if not self._max_depth:
54
+ return None
55
+
56
+ model_summary = summarize(pl_module, max_depth=self._max_depth)
57
+ summary_data = model_summary._get_summary_data()
58
+ total_parameters = model_summary.total_parameters
59
+ trainable_parameters = model_summary.trainable_parameters
60
+ model_size = model_summary.model_size
61
+
62
+ if trainer.is_global_zero:
63
+ self.summarize(summary_data, total_parameters, trainable_parameters, model_size)
64
+
65
+ @staticmethod
66
+ def summarize(
67
+ summary_data: List[Tuple[str, List[str]]],
68
+ total_parameters: int,
69
+ trainable_parameters: int,
70
+ model_size: float,
71
+ ) -> None:
72
+ summary_table = _format_summary_table(total_parameters, trainable_parameters, model_size, *summary_data)
73
+ log.info("\n" + summary_table)
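The callback logs the same table that the `summarize` helper produces; a small sketch outside the Trainer loop, using a toy module that is purely an assumption for illustration:

# Sketch: inspect the layer table directly with summarize (same data the callback logs).
from torch import nn
import pytorch_lightning as pl
from pytorch_lightning.utilities.model_summary import summarize

class ToyModule(pl.LightningModule):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Linear(64, 10))

    def forward(self, x):
        return self.encoder(x)

print(summarize(ToyModule(), max_depth=2))  # columns: name, type, number of parameters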
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/prediction_writer.py ADDED
@@ -0,0 +1,119 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ r"""
15
+ BasePredictionWriter
16
+ ====================
17
+
18
+ Aids in saving predictions
19
+ """
20
+ from typing import Any, Optional, Sequence
21
+
22
+ import pytorch_lightning as pl
23
+ from pytorch_lightning.callbacks.base import Callback
24
+ from pytorch_lightning.utilities import LightningEnum
25
+ from pytorch_lightning.utilities.exceptions import MisconfigurationException
26
+
27
+
28
+ class WriteInterval(LightningEnum):
29
+ BATCH = "batch"
30
+ EPOCH = "epoch"
31
+ BATCH_AND_EPOCH = "batch_and_epoch"
32
+
33
+ @property
34
+ def on_batch(self) -> bool:
35
+ return self in (self.BATCH, self.BATCH_AND_EPOCH)
36
+
37
+ @property
38
+ def on_epoch(self) -> bool:
39
+ return self in (self.EPOCH, self.BATCH_AND_EPOCH)
40
+
41
+
42
+ class BasePredictionWriter(Callback):
43
+ """Base class to implement how the predictions should be stored.
44
+
45
+ Args:
46
+ write_interval: When to write.
47
+
48
+ Example::
49
+
50
+ import torch
51
+ from pytorch_lightning.callbacks import BasePredictionWriter
52
+
53
+ class CustomWriter(BasePredictionWriter):
54
+
55
+ def __init__(self, output_dir: str, write_interval: str):
56
+ super().__init__(write_interval)
57
+ self.output_dir = output_dir
58
+
59
+ def write_on_batch_end(
60
+ self, trainer, pl_module: 'LightningModule', prediction: Any, batch_indices: List[int], batch: Any,
61
+ batch_idx: int, dataloader_idx: int
62
+ ):
63
+ torch.save(prediction, os.path.join(self.output_dir, dataloader_idx, f"{batch_idx}.pt"))
64
+
65
+ def write_on_epoch_end(
66
+ self, trainer, pl_module: 'LightningModule', predictions: List[Any], batch_indices: List[Any]
67
+ ):
68
+ torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))
69
+ """
70
+
71
+ def __init__(self, write_interval: str = "batch") -> None:
72
+ if write_interval not in list(WriteInterval):
73
+ raise MisconfigurationException(f"`write_interval` should be one of {[i.value for i in WriteInterval]}.")
74
+ self.interval = WriteInterval(write_interval)
75
+
76
+ def write_on_batch_end(
77
+ self,
78
+ trainer: "pl.Trainer",
79
+ pl_module: "pl.LightningModule",
80
+ prediction: Any,
81
+ batch_indices: Optional[Sequence[int]],
82
+ batch: Any,
83
+ batch_idx: int,
84
+ dataloader_idx: int,
85
+ ) -> None:
86
+ """Override with the logic to write a single batch."""
87
+ raise NotImplementedError()
88
+
89
+ def write_on_epoch_end(
90
+ self,
91
+ trainer: "pl.Trainer",
92
+ pl_module: "pl.LightningModule",
93
+ predictions: Sequence[Any],
94
+ batch_indices: Optional[Sequence[Any]],
95
+ ) -> None:
96
+ """Override with the logic to write all batches."""
97
+ raise NotImplementedError()
98
+
99
+ def on_predict_batch_end(
100
+ self,
101
+ trainer: "pl.Trainer",
102
+ pl_module: "pl.LightningModule",
103
+ outputs: Any,
104
+ batch: Any,
105
+ batch_idx: int,
106
+ dataloader_idx: int,
107
+ ) -> None:
108
+ if not self.interval.on_batch:
109
+ return
110
+ batch_indices = trainer.predict_loop.epoch_loop.current_batch_indices
111
+ self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx)
112
+
113
+ def on_predict_epoch_end(
114
+ self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: Sequence[Any]
115
+ ) -> None:
116
+ if not self.interval.on_epoch:
117
+ return
118
+ epoch_batch_indices = trainer.predict_loop.epoch_batch_indices
119
+ self.write_on_epoch_end(trainer, pl_module, trainer.predict_loop.predictions, epoch_batch_indices)
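A concrete writer wired into prediction might look like the sketch below; the subclass name, output directory, and the `model`/`predict_loader` objects are illustrative assumptions, not part of this file.

# Sketch: persist all predictions once per epoch through write_on_epoch_end.
import os
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BasePredictionWriter

class TorchPredictionWriter(BasePredictionWriter):  # hypothetical subclass
    def __init__(self, output_dir: str, write_interval: str = "epoch"):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        os.makedirs(self.output_dir, exist_ok=True)
        # `predictions` holds one list of per-batch outputs for each predict dataloader
        torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))

trainer = pl.Trainer(callbacks=[TorchPredictionWriter("predictions/")])
# trainer.predict(model, dataloaders=predict_loader) calls write_on_epoch_end once at the end;
# with write_interval="epoch", write_on_batch_end is never invoked.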
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/pruning.py ADDED
@@ -0,0 +1,486 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ r"""
15
+ ModelPruning
16
+ ^^^^^^^^^^^^
17
+ """
18
+ import inspect
19
+ import logging
20
+ from copy import deepcopy
21
+ from functools import partial
22
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
23
+
24
+ import torch
25
+ import torch.nn.utils.prune as pytorch_prune
26
+ from torch import nn
27
+ from typing_extensions import TypedDict
28
+
29
+ import pytorch_lightning as pl
30
+ from pytorch_lightning.callbacks.base import Callback
31
+ from pytorch_lightning.core.lightning import LightningModule
32
+ from pytorch_lightning.utilities.apply_func import apply_to_collection
33
+ from pytorch_lightning.utilities.exceptions import MisconfigurationException
34
+ from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only
35
+
36
+ log = logging.getLogger(__name__)
37
+
38
+ _PYTORCH_PRUNING_FUNCTIONS = {
39
+ "ln_structured": pytorch_prune.ln_structured,
40
+ "l1_unstructured": pytorch_prune.l1_unstructured,
41
+ "random_structured": pytorch_prune.random_structured,
42
+ "random_unstructured": pytorch_prune.random_unstructured,
43
+ }
44
+
45
+ _PYTORCH_PRUNING_METHOD = {
46
+ "ln_structured": pytorch_prune.LnStructured,
47
+ "l1_unstructured": pytorch_prune.L1Unstructured,
48
+ "random_structured": pytorch_prune.RandomStructured,
49
+ "random_unstructured": pytorch_prune.RandomUnstructured,
50
+ }
51
+
52
+ _PARAM_TUPLE = Tuple[nn.Module, str]
53
+ _PARAM_LIST = Sequence[_PARAM_TUPLE]
54
+ _MODULE_CONTAINERS = (LightningModule, nn.Sequential, nn.ModuleList, nn.ModuleDict)
55
+
56
+
57
+ class _LayerRef(TypedDict):
58
+ data: nn.Module
59
+ names: List[Tuple[int, str]]
60
+
61
+
62
+ class ModelPruning(Callback):
63
+ PARAMETER_NAMES = ("weight", "bias")
64
+
65
+ def __init__(
66
+ self,
67
+ pruning_fn: Union[Callable, str],
68
+ parameters_to_prune: _PARAM_LIST = (),
69
+ parameter_names: Optional[List[str]] = None,
70
+ use_global_unstructured: bool = True,
71
+ amount: Union[int, float, Callable[[int], Union[int, float]]] = 0.5,
72
+ apply_pruning: Union[bool, Callable[[int], bool]] = True,
73
+ make_pruning_permanent: bool = True,
74
+ use_lottery_ticket_hypothesis: Union[bool, Callable[[int], bool]] = True,
75
+ resample_parameters: bool = False,
76
+ pruning_dim: Optional[int] = None,
77
+ pruning_norm: Optional[int] = None,
78
+ verbose: int = 0,
79
+ prune_on_train_epoch_end: bool = True,
80
+ ) -> None:
81
+ """Model pruning Callback, using PyTorch's prune utilities. This callback is responsible of pruning
82
+ the network's parameters during training.
83
+
84
+ To learn more about pruning with PyTorch, please take a look at
85
+ `this tutorial <https://pytorch.org/tutorials/intermediate/pruning_tutorial.html>`_.
86
+
87
+ .. warning:: ``ModelPruning`` is in beta and subject to change.
88
+
89
+ .. code-block:: python
90
+
91
+ parameters_to_prune = [(model.mlp_1, "weight"), (model.mlp_2, "weight")]
92
+
93
+ trainer = Trainer(
94
+ callbacks=[
95
+ ModelPruning(
96
+ pruning_fn="l1_unstructured",
97
+ parameters_to_prune=parameters_to_prune,
98
+ amount=0.01,
99
+ use_global_unstructured=True,
100
+ )
101
+ ]
102
+ )
103
+
104
+ When ``parameters_to_prune`` is ``None``, ``parameters_to_prune`` will contain all parameters from the model.
105
+ The user can override ``filter_parameters_to_prune`` to filter any ``nn.Module`` to be pruned.
106
+
107
+ Args:
108
+
109
+ pruning_fn: Function from torch.nn.utils.prune module or your own PyTorch ``BasePruningMethod`` subclass.
110
+ Can also be string e.g. `"l1_unstructured"`. See pytorch docs for more details.
111
+
112
+ parameters_to_prune: List of tuples ``(nn.Module, "parameter_name_string")``.
113
+
114
+ parameter_names: List of parameter names to be pruned from the nn.Module.
115
+ Can either be ``"weight"`` or ``"bias"``.
116
+
117
+ use_global_unstructured: Whether to apply pruning globally on the model.
118
+ If ``parameters_to_prune`` is provided, global unstructured will be restricted on them.
119
+
120
+ amount: Quantity of parameters to prune:
121
+
122
+ - ``float``. Between 0.0 and 1.0. Represents the fraction of parameters to prune.
123
+ - ``int``. Represents the absolute number of parameters to prune.
124
+ - ``Callable``. For dynamic values. Will be called every epoch. Should return a value.
125
+
126
+ apply_pruning: Whether to apply pruning.
127
+
128
+ - ``bool``. Always apply it or not.
129
+ - ``Callable[[epoch], bool]``. For dynamic values. Will be called every epoch.
130
+
131
+ make_pruning_permanent: Whether to remove all reparametrization pre-hooks and apply masks
132
+ when training ends or the model is saved.
133
+
134
+ use_lottery_ticket_hypothesis: See `The lottery ticket hypothesis <https://arxiv.org/abs/1803.03635>`_:
135
+
136
+ - ``bool``. Whether to apply it or not.
137
+ - ``Callable[[epoch], bool]``. For dynamic values. Will be called every epoch.
138
+
139
+ resample_parameters: Used with ``use_lottery_ticket_hypothesis``. If True, the model parameters will
140
+ be resampled, otherwise, the exact original parameters will be used.
141
+
142
+ pruning_dim: If you are using a structured pruning method you need to specify the dimension.
143
+
144
+ pruning_norm: If you are using ``ln_structured`` you need to specify the norm.
145
+
146
+ verbose: Verbosity level. 0 to disable, 1 to log overall sparsity, 2 to log per-layer sparsity
147
+
148
+ prune_on_train_epoch_end: whether to apply pruning at the end of the training epoch.
149
+ If this is ``False``, then the check runs at the end of the validation epoch.
150
+
151
+ Raises:
152
+ MisconfigurationException:
153
+ If ``parameter_names`` is neither ``"weight"`` nor ``"bias"``,
154
+ if the provided ``pruning_fn`` is not supported,
155
+ if ``pruning_dim`` is not provided when ``"unstructured"``,
156
+ if ``pruning_norm`` is not provided when ``"ln_structured"``,
157
+ if ``pruning_fn`` is neither ``str`` nor :class:`torch.nn.utils.prune.BasePruningMethod`, or
158
+ if ``amount`` is none of ``int``, ``float`` and ``Callable``.
159
+ """
160
+
161
+ self._use_global_unstructured = use_global_unstructured
162
+ self._parameters_to_prune = parameters_to_prune
163
+ self._use_lottery_ticket_hypothesis = use_lottery_ticket_hypothesis
164
+ self._resample_parameters = resample_parameters
165
+ self._prune_on_train_epoch_end = prune_on_train_epoch_end
166
+ self._parameter_names = parameter_names or self.PARAMETER_NAMES
167
+ self._global_kwargs: Dict[str, Any] = {}
168
+ self._original_layers: Optional[Dict[int, _LayerRef]] = None
169
+ self._pruning_method_name: Optional[str] = None
170
+
171
+ for name in self._parameter_names:
172
+ if name not in self.PARAMETER_NAMES:
173
+ raise MisconfigurationException(
174
+ f"The provided `parameter_names` name: {name} isn't in {self.PARAMETER_NAMES}"
175
+ )
176
+
177
+ if isinstance(pruning_fn, str):
178
+ pruning_kwargs = {}
179
+ pruning_fn = pruning_fn.lower()
180
+ if pruning_fn not in _PYTORCH_PRUNING_FUNCTIONS:
181
+ raise MisconfigurationException(
182
+ f"The provided `pruning_fn` {pruning_fn} isn't available in PyTorch's"
183
+ f" built-in functions: {list(_PYTORCH_PRUNING_FUNCTIONS.keys())} "
184
+ )
185
+ if pruning_fn.endswith("_structured"):
186
+ if pruning_dim is None:
187
+ raise MisconfigurationException(
188
+ "When requesting `structured` pruning, the `pruning_dim` should be provided."
189
+ )
190
+ if pruning_fn == "ln_structured":
191
+ if pruning_norm is None:
192
+ raise MisconfigurationException(
193
+ "When requesting `ln_structured` pruning, the `pruning_norm` should be provided."
194
+ )
195
+ pruning_kwargs["n"] = pruning_norm
196
+ pruning_kwargs["dim"] = pruning_dim
197
+ pruning_fn = self._create_pruning_fn(pruning_fn, **pruning_kwargs)
198
+ elif self._is_pruning_method(pruning_fn):
199
+ if not use_global_unstructured:
200
+ raise MisconfigurationException(
201
+ "PyTorch `BasePruningMethod` is currently only supported with `use_global_unstructured=True`."
202
+ )
203
+ else:
204
+ raise MisconfigurationException(
205
+ f"`pruning_fn` is expected to be a str in {list(_PYTORCH_PRUNING_FUNCTIONS.keys())}"
206
+ f" or a PyTorch `BasePruningMethod`. Found: {pruning_fn}."
207
+ " HINT: if passing a `BasePruningMethod`, pass the the class, not an instance"
208
+ )
209
+
210
+ # need to ignore typing here since pytorch base class does not define the PRUNING_TYPE attribute
211
+ if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured": # type: ignore
212
+ raise MisconfigurationException(
213
+ 'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.' # type: ignore
214
+ f" Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. "
215
+ )
216
+
217
+ self.pruning_fn = pruning_fn
218
+ self._apply_pruning = apply_pruning
219
+ self._make_pruning_permanent = make_pruning_permanent
220
+
221
+ if not (isinstance(amount, (int, float)) or callable(amount)):
222
+ raise MisconfigurationException(
223
+ "`amount` should be provided and be either an int, a float or Callable function."
224
+ )
225
+
226
+ self.amount = amount
227
+
228
+ if verbose not in (0, 1, 2):
229
+ raise MisconfigurationException("`verbose` must be any of (0, 1, 2)")
230
+
231
+ self._verbose = verbose
232
+
233
+ def filter_parameters_to_prune(self, parameters_to_prune: _PARAM_LIST = ()) -> _PARAM_LIST:
234
+ """This function can be overridden to control which module to prune."""
235
+ return parameters_to_prune
236
+
237
+ def _create_pruning_fn(self, pruning_fn: str, **kwargs: Any) -> Union[Callable, pytorch_prune.BasePruningMethod]:
238
+ """This function takes `pruning_fn`, a function name.
239
+
240
+ IF use_global_unstructured, pruning_fn will be resolved into its associated ``PyTorch BasePruningMethod`` ELSE,
241
+ pruning_fn will be resolved into its function counterpart from `torch.nn.utils.prune`.
242
+ """
243
+ pruning_meth = (
244
+ _PYTORCH_PRUNING_METHOD[pruning_fn]
245
+ if self._use_global_unstructured
246
+ else _PYTORCH_PRUNING_FUNCTIONS[pruning_fn]
247
+ )
248
+ assert callable(pruning_meth), "Selected pruning method is not callable"
249
+ if self._use_global_unstructured:
250
+ self._global_kwargs = kwargs
251
+ # save the function __name__ now because partial does not include it
252
+ # and there are issues setting the attribute manually in ddp.
253
+ self._pruning_method_name = pruning_meth.__name__
254
+ if self._use_global_unstructured:
255
+ return pruning_meth
256
+ return ModelPruning._wrap_pruning_fn(pruning_meth, **kwargs)
257
+
258
+ @staticmethod
259
+ def _wrap_pruning_fn(pruning_fn: Callable, **kwargs: Any) -> Callable:
260
+ return partial(pruning_fn, **kwargs)
261
+
262
+ def make_pruning_permanent(self, module: nn.Module) -> None:
263
+ """Removes pruning buffers from any pruned modules.
264
+
265
+ Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180
266
+ """
267
+ for _, module in module.named_modules():
268
+ for k in list(module._forward_pre_hooks):
269
+ hook = module._forward_pre_hooks[k]
270
+ if isinstance(hook, pytorch_prune.BasePruningMethod):
271
+ hook.remove(module)
272
+ del module._forward_pre_hooks[k]
273
+
274
+ @staticmethod
275
+ def _copy_param(new: nn.Module, old: nn.Module, name: str) -> None:
276
+ dst = getattr(new, name)
277
+ src = getattr(old, name)
278
+ if dst is None or src is None or not isinstance(dst, torch.Tensor) or not isinstance(src, torch.Tensor):
279
+ return
280
+ dst.data = src.data.to(dst.device)
281
+
282
+ def apply_lottery_ticket_hypothesis(self) -> None:
283
+ r"""
284
+ Lottery ticket hypothesis algorithm (see page 2 of the paper):
285
+
286
+ 1. Randomly initialize a neural network :math:`f(x; \theta_0)` (where :math:`\theta_0 \sim \mathcal{D}_\theta`).
287
+ 2. Train the network for :math:`j` iterations, arriving at parameters :math:`\theta_j`.
288
+ 3. Prune :math:`p\%` of the parameters in :math:`\theta_j`, creating a mask :math:`m`.
289
+ 4. Reset the remaining parameters to their values in :math:`\theta_0`, creating the winning ticket :math:`f(x; m \odot \theta_0)`.
290
+
291
+ This function implements the step 4.
292
+
293
+ The ``resample_parameters`` argument can be used to reset the parameters with a new :math:`\theta_z \sim \mathcal{D}_\theta`
294
+ """ # noqa: E501
295
+ assert self._original_layers is not None
296
+ for d in self._original_layers.values():
297
+ copy = d["data"]
298
+ names = d["names"]
299
+ if self._resample_parameters and hasattr(copy, "reset_parameters") and callable(copy.reset_parameters):
300
+ copy = deepcopy(copy) # keep the original parameters
301
+ copy.reset_parameters()
302
+ for i, name in names:
303
+ new, new_name = self._parameters_to_prune[i]
304
+ self._copy_param(new, copy, name)
305
+
306
+ def _apply_local_pruning(self, amount: float) -> None:
307
+ for module, name in self._parameters_to_prune:
308
+ self.pruning_fn(module, name=name, amount=amount)
309
+
310
+ def _resolve_global_kwargs(self, amount: float) -> Dict[str, Any]:
311
+ self._global_kwargs["amount"] = amount
312
+ params = set(inspect.signature(self.pruning_fn).parameters)
313
+ params.discard("self")
314
+ return {k: v for k, v in self._global_kwargs.items() if k in params}
315
+
316
+ def _apply_global_pruning(self, amount: float) -> None:
317
+ pytorch_prune.global_unstructured(
318
+ self._parameters_to_prune, pruning_method=self.pruning_fn, **self._resolve_global_kwargs(amount)
319
+ )
320
+
321
+ @staticmethod
322
+ def _get_pruned_stats(module: nn.Module, name: str) -> Tuple[int, int]:
323
+ attr = f"{name}_mask"
324
+ if not hasattr(module, attr):
325
+ return 0, 1
326
+ mask = getattr(module, attr)
327
+ return (mask == 0).sum().item(), mask.numel()
328
+
329
+ def apply_pruning(self, amount: Union[int, float]) -> None:
330
+ """Applies pruning to ``parameters_to_prune``."""
331
+ if self._verbose:
332
+ prev_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune]
333
+
334
+ if self._use_global_unstructured:
335
+ self._apply_global_pruning(amount)
336
+ else:
337
+ self._apply_local_pruning(amount)
338
+
339
+ if self._verbose:
340
+ curr_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune]
341
+ self._log_sparsity_stats(prev_stats, curr_stats, amount=amount)
342
+
343
+ @rank_zero_only
344
+ def _log_sparsity_stats(
345
+ self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0
346
+ ) -> None:
347
+ total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters())
348
+ prev_total_zeros = sum(zeros for zeros, _ in prev)
349
+ curr_total_zeros = sum(zeros for zeros, _ in curr)
350
+ log.info(
351
+ f"Applied `{self._pruning_method_name}`. Pruned:"
352
+ f" {prev_total_zeros}/{total_params} ({prev_total_zeros / total_params:.2%}) ->"
353
+ f" {curr_total_zeros}/{total_params} ({curr_total_zeros / total_params:.2%})"
354
+ )
355
+ if self._verbose == 2:
356
+ for i, (module, name) in enumerate(self._parameters_to_prune):
357
+ prev_mask_zeros, prev_mask_size = prev[i]
358
+ curr_mask_zeros, curr_mask_size = curr[i]
359
+ log.info(
360
+ f"Applied `{self._pruning_method_name}` to `{module!r}.{name}` with amount={amount}. Pruned:"
361
+ f" {prev_mask_zeros} ({prev_mask_zeros / prev_mask_size:.2%}) ->"
362
+ f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})"
363
+ )
364
+
365
+ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
366
+ parameters_to_prune = self.sanitize_parameters_to_prune(
367
+ pl_module, self._parameters_to_prune, parameter_names=self._parameter_names
368
+ )
369
+
370
+ self._parameters_to_prune = self.filter_parameters_to_prune(parameters_to_prune)
371
+
372
+ if self._use_lottery_ticket_hypothesis:
373
+ # group modules by id. Each entry has a copy of the initial data
374
+ # and a list of the associated parameter names to prune
375
+ self._original_layers = {}
376
+ for i, (module, name) in enumerate(self._parameters_to_prune):
377
+ id_ = id(module)
378
+ self._original_layers.setdefault(id_, _LayerRef(data=deepcopy(module), names=[]))
379
+ self._original_layers[id_]["names"].append((i, name))
380
+
381
+ def _run_pruning(self, current_epoch: int) -> None:
382
+ prune = self._apply_pruning(current_epoch) if callable(self._apply_pruning) else self._apply_pruning
383
+ amount = self.amount(current_epoch) if callable(self.amount) else self.amount
384
+ if not prune or not amount:
385
+ return
386
+ self.apply_pruning(amount)
387
+
388
+ if (
389
+ self._use_lottery_ticket_hypothesis(current_epoch)
390
+ if callable(self._use_lottery_ticket_hypothesis)
391
+ else self._use_lottery_ticket_hypothesis
392
+ ):
393
+ self.apply_lottery_ticket_hypothesis()
394
+
395
+ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None:
396
+ if self._prune_on_train_epoch_end:
397
+ rank_zero_debug("`ModelPruning.on_train_epoch_end`. Applying pruning")
398
+ self._run_pruning(pl_module.current_epoch)
399
+
400
+ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
401
+ if not trainer.sanity_checking and not self._prune_on_train_epoch_end:
402
+ rank_zero_debug("`ModelPruning.on_validation_epoch_end`. Applying pruning")
403
+ self._run_pruning(pl_module.current_epoch)
404
+
405
+ def on_train_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None:
406
+ if self._make_pruning_permanent:
407
+ rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint")
408
+ self.make_pruning_permanent(pl_module)
409
+
410
+ def _make_pruning_permanent_on_state_dict(self, pl_module: LightningModule) -> Dict[str, Any]:
411
+ state_dict = pl_module.state_dict()
412
+
413
+ # find the mask and the original weights.
414
+ map_pruned_params = {k.replace("_mask", "") for k in state_dict.keys() if k.endswith("_mask")}
415
+ for tensor_name in map_pruned_params:
416
+ orig = state_dict.pop(tensor_name + "_orig")
417
+ mask = state_dict.pop(tensor_name + "_mask")
418
+ # make weights permanent
419
+ state_dict[tensor_name] = mask.to(dtype=orig.dtype) * orig
420
+
421
+ def move_to_cpu(tensor: torch.Tensor) -> torch.Tensor:
422
+ # move each tensor to the CPU
423
+ return tensor.cpu()
424
+
425
+ return apply_to_collection(state_dict, torch.Tensor, move_to_cpu)
426
+
427
+ def on_save_checkpoint(
428
+ self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
429
+ ) -> Optional[dict]:
430
+ if self._make_pruning_permanent:
431
+ rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint")
432
+ # manually prune the weights so training can keep going with the same buffers
433
+ checkpoint["state_dict"] = self._make_pruning_permanent_on_state_dict(pl_module)
434
+
435
+ @staticmethod
436
+ def sanitize_parameters_to_prune(
437
+ pl_module: LightningModule, parameters_to_prune: _PARAM_LIST = (), parameter_names: Sequence[str] = ()
438
+ ) -> _PARAM_LIST:
439
+ """This function is responsible of sanitizing ``parameters_to_prune`` and ``parameter_names``. If
440
+ ``parameters_to_prune is None``, it will be generated with all parameters of the model.
441
+
442
+ Raises:
443
+ MisconfigurationException:
444
+ If ``parameters_to_prune`` doesn't exist in the model, or
445
+ if ``parameters_to_prune`` is neither a list nor a tuple.
446
+ """
447
+ parameters = parameter_names or ModelPruning.PARAMETER_NAMES
448
+
449
+ current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)]
450
+
451
+ if not parameters_to_prune:
452
+ parameters_to_prune = [
453
+ (m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None
454
+ ]
455
+ elif (
456
+ isinstance(parameters_to_prune, (list, tuple))
457
+ and len(parameters_to_prune) > 0
458
+ and all(len(p) == 2 for p in parameters_to_prune)
459
+ and all(isinstance(a, nn.Module) and isinstance(b, str) for a, b in parameters_to_prune)
460
+ ):
461
+ missing_modules, missing_parameters = [], []
462
+ for module, name in parameters_to_prune:
463
+ if module not in current_modules:
464
+ missing_modules.append(module)
465
+ continue
466
+ if not hasattr(module, name):
467
+ missing_parameters.append(name)
468
+
469
+ if missing_modules or missing_parameters:
470
+ raise MisconfigurationException(
471
+ "Some provided `parameters_to_prune` don't exist in the model."
472
+ f" Found missing modules: {missing_modules} and missing parameters: {missing_parameters}"
473
+ )
474
+ else:
475
+ raise MisconfigurationException(
476
+ "The provided `parameters_to_prune` should either be list of tuple"
477
+ " with 2 elements: (nn.Module, parameter_name_to_prune) or None"
478
+ )
479
+
480
+ return parameters_to_prune
481
+
482
+ @staticmethod
483
+ def _is_pruning_method(method: Any) -> bool:
484
+ if not inspect.isclass(method):
485
+ return False
486
+ return issubclass(method, pytorch_prune.BasePruningMethod)
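Because `amount` may be a callable evaluated every epoch (see `_run_pruning` above), pruning can follow a schedule; the schedule and trainer settings below are illustrative assumptions.

# Sketch: epoch-dependent global L1 unstructured pruning.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelPruning

def prune_amount(epoch: int) -> float:
    # prune 50% of the weights once at epoch 10, then 25% every 5th epoch afterwards;
    # returning 0.0 skips pruning for that epoch (a falsy amount short-circuits _run_pruning)
    if epoch == 10:
        return 0.5
    if epoch > 10 and epoch % 5 == 0:
        return 0.25
    return 0.0

trainer = pl.Trainer(
    max_epochs=30,
    callbacks=[ModelPruning("l1_unstructured", amount=prune_amount, use_global_unstructured=True)],
)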
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/quantization.py ADDED
@@ -0,0 +1,344 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ r"""
15
+ Quantization
16
+ ^^^^^^^^^^^^
17
+
18
+ """
19
+ import copy
20
+ import functools
21
+ from typing import Any, Callable, Dict, Optional, Sequence, Union
22
+
23
+ import torch
24
+ from torch import Tensor
25
+ from torch.quantization import FakeQuantizeBase
26
+
27
+ import pytorch_lightning as pl
28
+ from pytorch_lightning.callbacks.base import Callback
29
+ from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11
30
+ from pytorch_lightning.utilities.exceptions import MisconfigurationException
31
+
32
+ if _TORCH_GREATER_EQUAL_1_10:
33
+ from torch.ao.quantization.qconfig import QConfig
34
+ else:
35
+ from torch.quantization import QConfig
36
+
37
+ if _TORCH_GREATER_EQUAL_1_11:
38
+ from torch.ao.quantization import fuse_modules_qat as fuse_modules
39
+ else:
40
+ from torch.quantization import fuse_modules
41
+
42
+
43
+ def wrap_qat_forward_context(
44
+ quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
45
+ ) -> Callable:
46
+ """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out
47
+ compatibility. Moreover, this version makes the (de)quantization conditional, as it may not be needed for the
48
+ training all the time."""
49
+ # todo: consider using registering hook before/after forward
50
+ @functools.wraps(func)
51
+ def wrapper(data) -> Any:
52
+ _is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer)
53
+ _is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition
54
+ _quant_run = trigger_condition is None or _is_func_true or _is_count_true
55
+ # apply custom trigger
56
+ if _quant_run:
57
+ quant_cb._forward_calls += 1
58
+ data = model.quant(data)
59
+ data = func(data)
60
+ # apply custom trigger
61
+ if _quant_run:
62
+ data = model.dequant(data)
63
+ return data
64
+
65
+ return wrapper
66
+
67
+
68
+ def wrap_quantize_forward_context(model: "pl.LightningModule", func: Callable) -> Callable:
69
+ """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out
70
+ compatibility."""
71
+ # todo: consider using registering hook before/after forward
72
+ @functools.wraps(func)
73
+ def wrapper(data) -> Any:
74
+ data = model.quant(data)
75
+ data = func(data)
76
+ data = model.dequant(data)
77
+ return data
78
+
79
+ return wrapper
80
+
81
+
82
+ def _recursive_hasattr(obj: Any, attribs: str, state: bool = True) -> bool:
83
+ """recursive check if model has some layers denoted with '.'."""
84
+ if "." in attribs:
85
+ attrib, attribs = attribs.split(".", 1)
86
+ if hasattr(obj, attrib):
87
+ return _recursive_hasattr(getattr(obj, attrib), attribs, state)
88
+ return False
89
+ return state and hasattr(obj, attribs)
90
+
91
+
92
+ class QuantizationAwareTraining(Callback):
93
+ """Quantization allows speeding up inference and decreasing memory requirements by performing computations and
94
+ storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating point precision. We use native
95
+ PyTorch API so for more information see `PyTorch Quantization`_.
96
+
97
+ .. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.
98
+
99
+ The ``LightningModule`` is prepared for QAT training in the ``on_fit_start`` hook. Checkpoints saved during training
100
+ include already collected stats to perform the Quantization conversion, but it doesn't contain the quantized or
101
+ fused model/layers. The quantization is performed in the ``on_fit_end`` hook so the model needs to be saved after
102
+ training finishes if quantization is desired.
103
+
104
+ Args:
105
+
106
+ qconfig: quantization configuration:
107
+
108
+ - 'fbgemm' for server inference.
109
+ - 'qnnpack' for mobile inference.
110
+ - a custom `torch.quantization.QConfig`_.
111
+
112
+ observer_type: allows switching between ``MovingAverageMinMaxObserver`` as "average" (default)
113
+ and ``HistogramObserver`` as "histogram" which is more computationally expensive.
114
+
115
+ collect_quantization: count or custom function to collect quantization statistics:
116
+
117
+ - ``None`` (default). The quantization observer is called in each module forward
118
+ (useful for collecting extended statistic when using image/data augmentation).
119
+ - ``int``. Use to set a fixed number of calls, starting from the beginning.
120
+ - ``Callable``. Custom function with single trainer argument.
121
+ See this example to trigger only the last epoch:
122
+
123
+ .. code-block:: python
124
+
125
+ def custom_trigger_last(trainer):
126
+ return trainer.current_epoch == (trainer.max_epochs - 1)
127
+
128
+
129
+ QuantizationAwareTraining(collect_quantization=custom_trigger_last)
130
+
131
+ modules_to_fuse: allows you to fuse a few layers together as shown in
132
+ `diagram <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_
133
+ to find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286.
134
+
135
+ input_compatible: preserve quant/dequant layers. This allows feeding any input to the model, as with the original model,
136
+ but breaks compatibility with torchscript and export via ``torch.save``.
137
+
138
+ quantize_on_fit_end: perform the quantization in `on_fit_end`.
139
+ Note that once converted, the model cannot be put in training mode again.
140
+
141
+ observer_enabled_stages: allow fake-quantization modules' observers to do calibration during provided stages:
142
+
143
+ - ``'train'``: the observers can do calibration during training.
144
+ - ``'validate'``: the observers can do calibration during validating.
145
+ Note that we don't disable observers during the sanity check as the model hasn't been calibrated with
146
+ training data yet. After the sanity check, the fake-quantization modules are restored to initial states.
147
+ - ``'test'``: the observers can do calibration during testing.
148
+ - ``'predict'``: the observers can do calibration during predicting.
149
+
150
+ Note that we only handle observers belonging to fake-quantization modules. When ``qconfig`` is a ``str`` and
151
+ ``observer_type`` is ``'histogram'``, the observers won't belong to any fake-quantization modules and will
152
+ not be controlled by the callback.
153
+
154
+ .. _PyTorch Quantization: https://pytorch.org/docs/stable/quantization.html#quantization-aware-training
155
+ .. _torch.quantization.QConfig: https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig
156
+ """
157
+
158
+ OBSERVER_TYPES = ("histogram", "average")
159
+ OBSERVER_STAGES = ("train", "validate", "test", "predict")
160
+
161
+ def __init__(
162
+ self,
163
+ qconfig: Union[str, QConfig] = "fbgemm",
164
+ observer_type: str = "average",
165
+ collect_quantization: Optional[Union[int, Callable]] = None,
166
+ modules_to_fuse: Optional[Sequence] = None,
167
+ input_compatible: bool = True,
168
+ quantize_on_fit_end: bool = True,
169
+ observer_enabled_stages: Sequence[str] = ("train",),
170
+ ) -> None:
171
+ _valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines
172
+ if not isinstance(qconfig, QConfig) and not _valid_qconf_str:
173
+ raise MisconfigurationException(
174
+ f"Unsupported qconfig: {qconfig}.\nTry one of defaults: {torch.backends.quantized.supported_engines}"
175
+ )
176
+ self._qconfig = qconfig
177
+
178
+ if observer_type not in self.OBSERVER_TYPES:
179
+ raise MisconfigurationException(
180
+ f'Unsupported observer type "{observer_type}", allowed are {self.OBSERVER_TYPES}.'
181
+ )
182
+ self._observer_type = observer_type
183
+
184
+ if collect_quantization is not None and not isinstance(collect_quantization, (int, Callable)):
185
+ raise MisconfigurationException(
186
+ f'Unsupported `collect_quantization` "{collect_quantization}", allowed are `int` or `Callable`.'
187
+ )
188
+ self._collect_quantization = collect_quantization
189
+
190
+ self._modules_to_fuse = modules_to_fuse
191
+ self._input_compatible = input_compatible
192
+ self._convert_on_fit_end = quantize_on_fit_end
193
+
194
+ observer_enabled_stages = set(observer_enabled_stages)
195
+ unsupported_stages = observer_enabled_stages - set(self.OBSERVER_STAGES)
196
+ if unsupported_stages:
197
+ raise MisconfigurationException(
198
+ f'Unsupported stages "{tuple(sorted(unsupported_stages))}", allowed are {self.OBSERVER_STAGES}.'
199
+ )
200
+ self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages
201
+
202
+ self._forward_calls = 0
203
+ self._fake_quant_to_initial_state_dict = {}
204
+ self._last_fake_quant_to_observer_enabled = {}
205
+ self._module_prepared = False
206
+
207
+ def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool:
208
+ if not self._modules_to_fuse:
209
+ return False
210
+ for group in self._modules_to_fuse:
211
+ if not all(_recursive_hasattr(model, m) for m in group):
212
+ raise MisconfigurationException(
213
+ f"You have requested to fuse {group} but one or more of them is not your model attributes"
214
+ )
215
+ return True
216
+
217
+ def _collect_observer_enabled(self) -> Dict[FakeQuantizeBase, Tensor]:
218
+ return {
219
+ fake_quant: fake_quant.observer_enabled.clone() for fake_quant in self._fake_quant_to_initial_state_dict
220
+ }
221
+
222
+ def _disable_observer(self, pl_module: "pl.LightningModule") -> None:
223
+ self._last_fake_quant_to_observer_enabled = self._collect_observer_enabled()
224
+ pl_module.apply(torch.quantization.disable_observer)
225
+
226
+ def _restore_last_observer_enabled(self) -> None:
227
+ for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items():
228
+ fake_quant.observer_enabled.copy_(observer_enabled)
229
+
230
+ def _prepare_model(self, model: torch.nn.Module) -> None:
231
+ if self._module_prepared:
232
+ return
233
+ # QuantStub converts tensors from floating point to quantized
234
+ model.quant = torch.quantization.QuantStub()
235
+ # DeQuantStub converts tensors from quantized to floating point
236
+ model.dequant = torch.quantization.DeQuantStub()
237
+ # manually specify where tensors will be converted from quantized
238
+ # to floating point in the quantized model
239
+ self.__module_forward = model.forward
240
+ model.forward = wrap_qat_forward_context(
241
+ quant_cb=self, model=model, func=model.forward, trigger_condition=self._collect_quantization
242
+ )
243
+
244
+ # attach a global qconfig, which contains information about what kind
245
+ # of observers to attach. Use 'fbgemm' for server inference
246
+ if isinstance(self._qconfig, str):
247
+ if self._observer_type == "histogram":
248
+ model.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
249
+ elif self._observer_type == "average":
250
+ # version=None corresponds to using FakeQuantize rather than
251
+ # FusedMovingAvgObsFakeQuantize which was introduced in PT1.10
252
+ # details in https://github.com/pytorch/pytorch/issues/64564
253
+ extra_kwargs = dict(version=None) if _TORCH_GREATER_EQUAL_1_10 else {}
254
+ model.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)
255
+
256
+ elif isinstance(self._qconfig, QConfig):
257
+ model.qconfig = self._qconfig
258
+
259
+ if self._check_feasible_fuse(model):
260
+ fuse_modules(model, self._modules_to_fuse, inplace=True)
261
+
262
+ # Prepare the model for QAT. This inserts observers and fake_quants in
263
+ # the model that will observe weight and activation tensors during calibration.
264
+ torch.quantization.prepare_qat(model, inplace=True)
265
+
266
+ fake_quants = tuple(module for module in model.modules() if isinstance(module, FakeQuantizeBase))
267
+ self._fake_quant_to_initial_state_dict = {
268
+ fake_quant: copy.deepcopy(fake_quant.state_dict()) for fake_quant in fake_quants
269
+ }
270
+ self._module_prepared = True
271
+
272
+ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
273
+ self._prepare_model(pl_module)
274
+
275
+ def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
276
+ if not self._convert_on_fit_end:
277
+ pl_module.forward = self.__module_forward
278
+ return
279
+ pl_module.eval()
280
+ # Convert the observed model to a quantized model. This does several things:
281
+ # quantizes the weights, computes and stores the scale and bias value to be
282
+ # used with each activation tensor, fuses modules where appropriate,
283
+ # and replaces key operators with quantized implementations.
284
+ torch.quantization.convert(pl_module, inplace=True)
285
+ # check we shall preserve wrapper
286
+ if self._input_compatible:
287
+ pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward)
288
+ else:
289
+ pl_module.forward = self.__module_forward
290
+
291
+ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
292
+ if "train" in self._observer_disabled_stages:
293
+ self._disable_observer(pl_module)
294
+
295
+ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
296
+ if "train" in self._observer_disabled_stages:
297
+ self._restore_last_observer_enabled()
298
+
299
+ def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
300
+ if "validate" in self._observer_disabled_stages and not trainer.sanity_checking:
301
+ # ``torch.quantization.MovingAveragePerChannelMinMaxObserver`` and ``torch.quantization.HistogramObserver``
302
+ # need to see at least one batch to infer the shapes of quantization ``scale`` and ``zero_point``. So we
303
+ # don't disable observers during the sanity check so that they can infer the shapes of quantization
304
+ # parameters with validation data.
305
+ self._disable_observer(pl_module)
306
+
307
+ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
308
+ if "validate" in self._observer_disabled_stages:
309
+ if trainer.sanity_checking:
310
+ for fake_quant, state_dict in self._fake_quant_to_initial_state_dict.items():
311
+ fake_quant.load_state_dict(state_dict)
312
+ else:
313
+ self._restore_last_observer_enabled()
314
+
315
+ def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
316
+ if "test" in self._observer_disabled_stages:
317
+ self._disable_observer(pl_module)
318
+
319
+ def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
320
+ if "test" in self._observer_disabled_stages:
321
+ self._restore_last_observer_enabled()
322
+
323
+ def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
324
+ if "predict" in self._observer_disabled_stages:
325
+ self._disable_observer(pl_module)
326
+
327
+ def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
328
+ if "predict" in self._observer_disabled_stages:
329
+ self._restore_last_observer_enabled()
330
+
331
+ def state_dict(self) -> Dict[str, Any]:
332
+ keys = {"_qconfig", "_observer_type", "_collect_quantization", "_modules_to_fuse", "_input_compatible"}
333
+ return {n: getattr(self, n) for n in keys}
334
+
335
+ def _load_before_model(self, model: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
336
+ """Special hook that gets called by the CheckpointConnector *before* the model gets loaded.
337
+
338
+ This hook replaces the :meth:`on_load_checkpoint` and :meth:`load_state_dict` callback methods which get called
339
+ after the model has already loaded the weights. For quantization, we need to convert the model first before that
340
+ happens, assuming the previous training used quantization.
341
+ """
342
+ for k, v in state_dict.items():
343
+ setattr(self, k, v)
344
+ self._prepare_model(model)
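A minimal usage sketch for the callback above, assuming it is exported as ``QuantizationAwareTraining`` from ``pytorch_lightning.callbacks`` (as in the package's public API); ``MyModel`` and ``train_loader`` are user-defined placeholders:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import QuantizationAwareTraining

    # Attach quantization-aware training; the model is prepared on fit start
    # and converted to a quantized model on fit end (see the hooks above).
    qat = QuantizationAwareTraining(qconfig="fbgemm", observer_type="average")
    trainer = Trainer(max_epochs=3, callbacks=[qat])
    # trainer.fit(MyModel(), train_loader)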
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/rich_model_summary.py ADDED
@@ -0,0 +1,109 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List, Tuple
15
+
16
+ from pytorch_lightning.callbacks import ModelSummary
17
+ from pytorch_lightning.utilities.imports import _RICH_AVAILABLE
18
+ from pytorch_lightning.utilities.model_summary import get_human_readable_count
19
+
20
+ if _RICH_AVAILABLE:
21
+ from rich import get_console
22
+ from rich.table import Table
23
+
24
+
25
+ class RichModelSummary(ModelSummary):
26
+ r"""
27
+ Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`
28
+ with `rich text formatting <https://github.com/willmcgugan/rich>`_.
29
+
30
+ Install it with pip:
31
+
32
+ .. code-block:: bash
33
+
34
+ pip install rich
35
+
36
+ .. code-block:: python
37
+
38
+ from pytorch_lightning import Trainer
39
+ from pytorch_lightning.callbacks import RichModelSummary
40
+
41
+ trainer = Trainer(callbacks=RichModelSummary())
42
+
43
+ You could also enable ``RichModelSummary`` using the :class:`~pytorch_lightning.callbacks.RichProgressBar`
44
+
45
+ .. code-block:: python
46
+
47
+ from pytorch_lightning import Trainer
48
+ from pytorch_lightning.callbacks import RichProgressBar
49
+
50
+ trainer = Trainer(callbacks=RichProgressBar())
51
+
52
+ Args:
53
+ max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the
54
+ layer summary off.
55
+
56
+ Raises:
57
+ ModuleNotFoundError:
58
+ If required `rich` package is not installed on the device.
59
+ """
60
+
61
+ def __init__(self, max_depth: int = 1) -> None:
62
+ if not _RICH_AVAILABLE:
63
+ raise ModuleNotFoundError(
64
+ "`RichModelSummary` requires `rich` to be installed. Install it by running `pip install -U rich`."
65
+ )
66
+ super().__init__(max_depth)
67
+
68
+ @staticmethod
69
+ def summarize(
70
+ summary_data: List[Tuple[str, List[str]]],
71
+ total_parameters: int,
72
+ trainable_parameters: int,
73
+ model_size: float,
74
+ ) -> None:
75
+
76
+ console = get_console()
77
+
78
+ table = Table(header_style="bold magenta")
79
+ table.add_column(" ", style="dim")
80
+ table.add_column("Name", justify="left", no_wrap=True)
81
+ table.add_column("Type")
82
+ table.add_column("Params", justify="right")
83
+
84
+ column_names = list(zip(*summary_data))[0]
85
+
86
+ for column_name in ["In sizes", "Out sizes"]:
87
+ if column_name in column_names:
88
+ table.add_column(column_name, justify="right", style="white")
89
+
90
+ rows = list(zip(*(arr[1] for arr in summary_data)))
91
+ for row in rows:
92
+ table.add_row(*row)
93
+
94
+ console.print(table)
95
+
96
+ parameters = []
97
+ for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]:
98
+ parameters.append("{:<{}}".format(get_human_readable_count(int(param)), 10))
99
+
100
+ grid = Table.grid(expand=True)
101
+ grid.add_column()
102
+ grid.add_column()
103
+
104
+ grid.add_row(f"[bold]Trainable params[/]: {parameters[0]}")
105
+ grid.add_row(f"[bold]Non-trainable params[/]: {parameters[1]}")
106
+ grid.add_row(f"[bold]Total params[/]: {parameters[2]}")
107
+ grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}")
108
+
109
+ console.print(grid)
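For reference, ``summarize`` above expects ``summary_data`` as a list of ``(column name, column values)`` pairs; a minimal sketch with made-up layer rows (requires ``rich`` to be installed):

    from pytorch_lightning.callbacks import RichModelSummary

    summary_data = [
        (" ", ["0", "1"]),
        ("Name", ["encoder", "decoder"]),
        ("Type", ["Linear", "Linear"]),
        ("Params", ["2.6 K", "2.6 K"]),
    ]
    # Prints the layer table followed by the parameter-count grid.
    RichModelSummary.summarize(
        summary_data, total_parameters=5200, trainable_parameters=5200, model_size=0.02
    )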
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/stochastic_weight_avg.py ADDED
@@ -0,0 +1,280 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ r"""
15
+ Stochastic Weight Averaging Callback
16
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
17
+ """
18
+ from copy import deepcopy
19
+ from typing import Callable, List, Optional, Union
20
+
21
+ import torch
22
+ from torch import nn
23
+ from torch.optim.swa_utils import SWALR
24
+
25
+ import pytorch_lightning as pl
26
+ from pytorch_lightning.callbacks.base import Callback
27
+ from pytorch_lightning.utilities.exceptions import MisconfigurationException
28
+ from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
29
+ from pytorch_lightning.utilities.types import LRSchedulerConfig
30
+
31
+ _AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor]
32
+
33
+
34
+ class StochasticWeightAveraging(Callback):
35
+ def __init__(
36
+ self,
37
+ swa_epoch_start: Union[int, float] = 0.8,
38
+ swa_lrs: Optional[Union[float, List[float]]] = None,
39
+ annealing_epochs: int = 10,
40
+ annealing_strategy: str = "cos",
41
+ avg_fn: Optional[_AVG_FN] = None,
42
+ device: Optional[Union[torch.device, str]] = torch.device("cpu"),
43
+ ):
44
+ r"""
45
+
46
+ Implements the Stochastic Weight Averaging (SWA) Callback to average a model.
47
+
48
+ Stochastic Weight Averaging was proposed in ``Averaging Weights Leads to
49
+ Wider Optima and Better Generalization`` by Pavel Izmailov, Dmitrii
50
+ Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
51
+ (UAI 2018).
52
+
53
+ This documentation is highly inspired by PyTorch's work on SWA.
54
+ The callback arguments follow the scheme defined in PyTorch's ``swa_utils`` package.
55
+
56
+ For a SWA explanation, please take a look
57
+ `here <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`_.
58
+
59
+ .. warning:: ``StochasticWeightAveraging`` is in beta and subject to change.
60
+
61
+ .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers.
62
+
63
+ .. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch.
64
+
65
+ See also how to :ref:`enable it directly on the Trainer <advanced/training_tricks:Stochastic Weight Averaging>`
66
+
67
+ Arguments:
68
+
69
+ swa_epoch_start: If provided as int, the procedure will start from
70
+ the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1,
71
+ the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch
72
+
73
+ swa_lrs: The SWA learning rate to use:
74
+
75
+ - ``None``. Use the current learning rate of the optimizer at the time the SWA procedure starts.
76
+ - ``float``. Use this value for all parameter groups of the optimizer.
77
+ - ``List[float]``. A list values for each parameter group of the optimizer.
78
+
79
+ annealing_epochs: number of epochs in the annealing phase (default: 10)
80
+
81
+ annealing_strategy: Specifies the annealing strategy (default: "cos"):
82
+
83
+ - ``"cos"``. For cosine annealing.
84
+ - ``"linear"`` For linear annealing
85
+
86
+ avg_fn: the averaging function used to update the parameters;
87
+ the function must take in the current value of the
88
+ :class:`AveragedModel` parameter, the current value of :attr:`model`
89
+ parameter and the number of models already averaged; if None,
90
+ equally weighted average is used (default: ``None``)
91
+
92
+ device: if provided, the averaged model will be stored on the ``device``.
93
+ When None is provided, it will infer the `device` from ``pl_module``.
94
+ (default: ``"cpu"``)
95
+
96
+ """
97
+
98
+ err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1."
99
+ if isinstance(swa_epoch_start, int) and swa_epoch_start < 1:
100
+ raise MisconfigurationException(err_msg)
101
+ if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1):
102
+ raise MisconfigurationException(err_msg)
103
+
104
+ wrong_type = not isinstance(swa_lrs, (float, list))
105
+ wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0
106
+ wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
107
+ if swa_lrs is not None and (wrong_type or wrong_float or wrong_list):
108
+ raise MisconfigurationException(
109
+ "The `swa_lrs` should be `None`, a positive float, or a list of positive floats"
110
+ )
111
+
112
+ if avg_fn is not None and not isinstance(avg_fn, Callable):
113
+ raise MisconfigurationException("The `avg_fn` should be callable.")
114
+
115
+ if device is not None and not isinstance(device, (torch.device, str)):
116
+ raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")
117
+
118
+ self._swa_epoch_start = swa_epoch_start
119
+ self._swa_lrs = swa_lrs
120
+ self._annealing_epochs = annealing_epochs
121
+ self._annealing_strategy = annealing_strategy
122
+ self._avg_fn = avg_fn or self.avg_fn
123
+ self._device = device
124
+ self._model_contains_batch_norm = None
125
+ self._average_model = None
126
+
127
+ @property
128
+ def swa_start(self) -> int:
129
+ return max(self._swa_epoch_start - 1, 0) # 0-based
130
+
131
+ @property
132
+ def swa_end(self) -> int:
133
+ return self._max_epochs - 1 # 0-based
134
+
135
+ @staticmethod
136
+ def pl_module_contains_batch_norm(pl_module: "pl.LightningModule"):
137
+ return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())
138
+
139
+ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
140
+ # copy the model before moving it to accelerator device.
141
+ with pl_module._prevent_trainer_and_dataloaders_deepcopy():
142
+ self._average_model = deepcopy(pl_module)
143
+
144
+ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
145
+ if len(trainer.optimizers) != 1:
146
+ raise MisconfigurationException("SWA currently works with 1 `optimizer`.")
147
+
148
+ if len(trainer.lr_scheduler_configs) > 1:
149
+ raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")
150
+
151
+ if isinstance(self._swa_epoch_start, float):
152
+ self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)
153
+
154
+ self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)
155
+
156
+ self._max_epochs = trainer.max_epochs
157
+ if self._model_contains_batch_norm:
158
+ # virtually increase max_epochs to perform batch norm update on latest epoch.
159
+ trainer.fit_loop.max_epochs += 1
160
+
161
+ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
162
+ if trainer.current_epoch == self.swa_start:
163
+ # move average model to request device.
164
+ self._average_model = self._average_model.to(self._device or pl_module.device)
165
+
166
+ optimizer = trainer.optimizers[0]
167
+ if self._swa_lrs is None:
168
+ self._swa_lrs = [param_group["lr"] for param_group in optimizer.param_groups]
169
+ if isinstance(self._swa_lrs, float):
170
+ self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups)
171
+
172
+ for lr, group in zip(self._swa_lrs, optimizer.param_groups):
173
+ group["initial_lr"] = lr
174
+
175
+ self._swa_scheduler = SWALR(
176
+ optimizer,
177
+ swa_lr=self._swa_lrs,
178
+ anneal_epochs=self._annealing_epochs,
179
+ anneal_strategy=self._annealing_strategy,
180
+ last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
181
+ )
182
+ # We assert that there is only one optimizer on fit start, so know opt_idx is always 0
183
+ default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0)
184
+ assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1
185
+
186
+ if trainer.lr_scheduler_configs:
187
+ scheduler_cfg = trainer.lr_scheduler_configs[0]
188
+ if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1:
189
+ rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}")
190
+ rank_zero_info(
191
+ f"Swapping scheduler `{scheduler_cfg.scheduler.__class__.__name__}`"
192
+ f" for `{self._swa_scheduler.__class__.__name__}`"
193
+ )
194
+ trainer.lr_scheduler_configs[0] = default_scheduler_cfg
195
+ else:
196
+ trainer.lr_scheduler_configs.append(default_scheduler_cfg)
197
+
198
+ self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
199
+
200
+ if self.swa_start <= trainer.current_epoch <= self.swa_end:
201
+ self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)
202
+
203
+ # Note: No > here in case the callback is saved with the model and training continues
204
+ if trainer.current_epoch == self.swa_end + 1:
205
+
206
+ # Transfer weights from average model to pl_module
207
+ self.transfer_weights(self._average_model, pl_module)
208
+
209
+ # Reset BatchNorm for update
210
+ self.reset_batch_norm_and_save_state(pl_module)
211
+
212
+ # There is no need to perform either backward or optimizer.step as we are
213
+ # performing only one pass over the train data-loader to compute activation statistics
214
+ # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
215
+ trainer.num_training_batches += 1
216
+ trainer.fit_loop._skip_backward = True
217
+ self._accumulate_grad_batches = trainer.accumulate_grad_batches
218
+
219
+ trainer.accumulate_grad_batches = trainer.num_training_batches
220
+
221
+ def on_train_epoch_end(self, trainer: "pl.Trainer", *args):
222
+ trainer.fit_loop._skip_backward = False
223
+
224
+ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
225
+ # the trainer increases the current epoch before this hook is called
226
+ if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
227
+ # BatchNorm epoch update. Reset state
228
+ trainer.accumulate_grad_batches = self._accumulate_grad_batches
229
+ trainer.num_training_batches -= 1
230
+ trainer.fit_loop.max_epochs -= 1
231
+ self.reset_momenta()
232
+ elif trainer.current_epoch - 1 == self.swa_end:
233
+ # Last SWA epoch. Transfer weights from average model to pl_module
234
+ self.transfer_weights(self._average_model, pl_module)
235
+
236
+ @staticmethod
237
+ def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule"):
238
+ for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):
239
+ dst_param.detach().copy_(src_param.to(dst_param.device))
240
+
241
+ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule"):
242
+ """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154."""
243
+ self.momenta = {}
244
+ for module in pl_module.modules():
245
+ if not isinstance(module, nn.modules.batchnorm._BatchNorm):
246
+ continue
247
+ module.running_mean = torch.zeros_like(
248
+ module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype
249
+ )
250
+ module.running_var = torch.ones_like(
251
+ module.running_var, device=pl_module.device, dtype=module.running_var.dtype
252
+ )
253
+ self.momenta[module] = module.momentum
254
+ module.momentum = None
255
+ module.num_batches_tracked *= 0
256
+
257
+ def reset_momenta(self):
258
+ """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
259
+ for bn_module in self.momenta:
260
+ bn_module.momentum = self.momenta[bn_module]
261
+
262
+ @staticmethod
263
+ def update_parameters(
264
+ average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: torch.LongTensor, avg_fn: _AVG_FN
265
+ ):
266
+ """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112."""
267
+ for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
268
+ device = p_swa.device
269
+ p_swa_ = p_swa.detach()
270
+ p_model_ = p_model.detach().to(device)
271
+ src = p_model_ if n_averaged == 0 else avg_fn(p_swa_, p_model_, n_averaged.to(device))
272
+ p_swa_.copy_(src)
273
+ n_averaged += 1
274
+
275
+ @staticmethod
276
+ def avg_fn(
277
+ averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor
278
+ ) -> torch.FloatTensor:
279
+ """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
280
+ return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
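A minimal sketch of attaching the callback above, assuming a single optimizer (per the check in ``on_fit_start``); the epoch counts and learning rate are arbitrary placeholders:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import StochasticWeightAveraging

    # Start averaging after 75% of max_epochs, annealing each parameter group's
    # learning rate towards 1e-2 over 5 epochs with cosine annealing.
    swa = StochasticWeightAveraging(swa_epoch_start=0.75, swa_lrs=1e-2, annealing_epochs=5)
    trainer = Trainer(max_epochs=20, callbacks=[swa])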
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/timer.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ r"""
15
+ Timer
16
+ ^^^^^
17
+ """
18
+ import logging
19
+ import time
20
+ from datetime import timedelta
21
+ from typing import Any, Dict, Optional, Union
22
+
23
+ import pytorch_lightning as pl
24
+ from pytorch_lightning.callbacks.base import Callback
25
+ from pytorch_lightning.trainer.states import RunningStage
26
+ from pytorch_lightning.utilities import LightningEnum
27
+ from pytorch_lightning.utilities.exceptions import MisconfigurationException
28
+ from pytorch_lightning.utilities.rank_zero import rank_zero_info
29
+
30
+ log = logging.getLogger(__name__)
31
+
32
+
33
+ class Interval(LightningEnum):
34
+ step = "step"
35
+ epoch = "epoch"
36
+
37
+
38
+ class Timer(Callback):
39
+ """The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the
40
+ Trainer if the given time limit for the training loop is reached.
41
+
42
+ Args:
43
+ duration: A string in the format DD:HH:MM:SS (days, hours, minutes, seconds), or a :class:`datetime.timedelta`,
44
+ or a dict containing key-value compatible with :class:`~datetime.timedelta`.
45
+ interval: Determines if the interruption happens on epoch level or mid-epoch.
46
+ Can be either ``"epoch"`` or ``"step"``.
47
+ verbose: Set this to ``False`` to suppress logging messages.
48
+
49
+ Raises:
50
+ MisconfigurationException:
51
+ If ``interval`` is not one of the supported choices.
52
+
53
+ Example::
54
+ from pytorch_lightning import Trainer
55
+ from pytorch_lightning.callbacks import Timer
56
+
57
+ # stop training after 12 hours
58
+ timer = Timer(duration="00:12:00:00")
59
+
60
+ # or provide a datetime.timedelta
61
+ from datetime import timedelta
62
+ timer = Timer(duration=timedelta(weeks=1))
63
+
64
+ # or provide a dictionary
65
+ timer = Timer(duration=dict(weeks=4, days=2))
66
+
67
+ # force training to stop after given time limit
68
+ trainer = Trainer(callbacks=[timer])
69
+
70
+ # query training/validation/test time (in seconds)
71
+ timer.time_elapsed("train")
72
+ timer.start_time("validate")
73
+ timer.end_time("test")
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ duration: Optional[Union[str, timedelta, Dict[str, int]]] = None,
79
+ interval: str = Interval.step,
80
+ verbose: bool = True,
81
+ ) -> None:
82
+ super().__init__()
83
+ if isinstance(duration, str):
84
+ dhms = duration.strip().split(":")
85
+ dhms = [int(i) for i in dhms]
86
+ duration = timedelta(days=dhms[0], hours=dhms[1], minutes=dhms[2], seconds=dhms[3])
87
+ if isinstance(duration, dict):
88
+ duration = timedelta(**duration)
89
+ if interval not in set(Interval):
90
+ raise MisconfigurationException(
91
+ f"Unsupported parameter value `Timer(interval={interval})`. Possible choices are:"
92
+ f" {', '.join(set(Interval))}"
93
+ )
94
+ self._duration = duration.total_seconds() if duration is not None else None
95
+ self._interval = interval
96
+ self._verbose = verbose
97
+ self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
98
+ self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
99
+ self._offset = 0
100
+
101
+ def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]:
102
+ """Return the start time of a particular stage (in seconds)"""
103
+ stage = RunningStage(stage)
104
+ return self._start_time[stage]
105
+
106
+ def end_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]:
107
+ """Return the end time of a particular stage (in seconds)"""
108
+ stage = RunningStage(stage)
109
+ return self._end_time[stage]
110
+
111
+ def time_elapsed(self, stage: str = RunningStage.TRAINING) -> float:
112
+ """Return the time elapsed for a particular stage (in seconds)"""
113
+ start = self.start_time(stage)
114
+ end = self.end_time(stage)
115
+ offset = self._offset if stage == RunningStage.TRAINING else 0
116
+ if start is None:
117
+ return offset
118
+ if end is None:
119
+ return time.monotonic() - start + offset
120
+ return end - start + offset
121
+
122
+ def time_remaining(self, stage: str = RunningStage.TRAINING) -> Optional[float]:
123
+ """Return the time remaining for a particular stage (in seconds)"""
124
+ if self._duration is not None:
125
+ return self._duration - self.time_elapsed(stage)
126
+
127
+ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
128
+ self._start_time[RunningStage.TRAINING] = time.monotonic()
129
+
130
+ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
131
+ self._end_time[RunningStage.TRAINING] = time.monotonic()
132
+
133
+ def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
134
+ self._start_time[RunningStage.VALIDATING] = time.monotonic()
135
+
136
+ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
137
+ self._end_time[RunningStage.VALIDATING] = time.monotonic()
138
+
139
+ def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
140
+ self._start_time[RunningStage.TESTING] = time.monotonic()
141
+
142
+ def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
143
+ self._end_time[RunningStage.TESTING] = time.monotonic()
144
+
145
+ def on_fit_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
146
+ # this checks the time after the state is reloaded, regardless of the interval.
147
+ # this is necessary in case we load a state whose timer is already depleted
148
+ if self._duration is None:
149
+ return
150
+ self._check_time_remaining(trainer)
151
+
152
+ def on_train_batch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
153
+ if self._interval != Interval.step or self._duration is None:
154
+ return
155
+ self._check_time_remaining(trainer)
156
+
157
+ def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
158
+ if self._interval != Interval.epoch or self._duration is None:
159
+ return
160
+ self._check_time_remaining(trainer)
161
+
162
+ def state_dict(self) -> Dict[str, Any]:
163
+ return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in list(RunningStage)}}
164
+
165
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
166
+ time_elapsed = state_dict.get("time_elapsed", {})
167
+ self._offset = time_elapsed.get(RunningStage.TRAINING.value, 0)
168
+
169
+ def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
170
+ assert self._duration is not None
171
+ should_stop = self.time_elapsed() >= self._duration
172
+ should_stop = trainer.strategy.broadcast(should_stop)
173
+ trainer.should_stop = trainer.should_stop or should_stop
174
+ if should_stop and self._verbose:
175
+ elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING)))
176
+ rank_zero_info(f"Time limit reached. Elapsed time is {elapsed}. Signaling Trainer to stop.")
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/callbacks/xla_stats_monitor.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ XLA Stats Monitor
16
+ =================
17
+
18
+ Monitors and logs XLA stats during training.
19
+
20
+ """
21
+ import time
22
+
23
+ import pytorch_lightning as pl
24
+ from pytorch_lightning.accelerators import TPUAccelerator
25
+ from pytorch_lightning.callbacks.base import Callback
26
+ from pytorch_lightning.utilities import _TPU_AVAILABLE
27
+ from pytorch_lightning.utilities.exceptions import MisconfigurationException
28
+ from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info
29
+
30
+ if _TPU_AVAILABLE:
31
+ import torch_xla.core.xla_model as xm
32
+
33
+
34
+ class XLAStatsMonitor(Callback):
35
+ r"""
36
+ .. deprecated:: v1.5
37
+ The `XLAStatsMonitor` callback was deprecated in v1.5 and will be removed in v1.7.
38
+ Please use the `DeviceStatsMonitor` callback instead.
39
+
40
+ Automatically monitors and logs XLA stats during the training stage. ``XLAStatsMonitor`` is a callback and in
41
+ order to use it you need to assign a logger in the ``Trainer``.
42
+
43
+ Args:
44
+ verbose: Set to ``True`` to print average peak and free memory, and epoch time
45
+ every epoch.
46
+
47
+ Raises:
48
+ MisconfigurationException:
49
+ If not running on TPUs, or ``Trainer`` has no logger.
50
+
51
+ Example::
52
+
53
+ >>> from pytorch_lightning import Trainer
54
+ >>> from pytorch_lightning.callbacks import XLAStatsMonitor
55
+ >>> xla_stats = XLAStatsMonitor() # doctest: +SKIP
56
+ >>> trainer = Trainer(callbacks=[xla_stats]) # doctest: +SKIP
57
+ """
58
+
59
+ def __init__(self, verbose: bool = True) -> None:
60
+ super().__init__()
61
+
62
+ rank_zero_deprecation(
63
+ "The `XLAStatsMonitor` callback was deprecated in v1.5 and will be removed in v1.7."
64
+ " Please use the `DeviceStatsMonitor` callback instead."
65
+ )
66
+
67
+ if not _TPU_AVAILABLE:
68
+ raise MisconfigurationException("Cannot use XLAStatsMonitor with TPUs are not available")
69
+
70
+ self._verbose = verbose
71
+
72
+ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
73
+ if not trainer.loggers:
74
+ raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
75
+
76
+ if not isinstance(trainer.accelerator, TPUAccelerator):
77
+ raise MisconfigurationException(
78
+ "You are using XLAStatsMonitor but are not running on TPU."
79
+ f" The accelerator is set to {trainer.accelerator.__class__.__name__}."
80
+ )
81
+
82
+ device = trainer.strategy.root_device
83
+ memory_info = xm.get_memory_info(device)
84
+ total_memory = trainer.strategy.reduce(memory_info["kb_total"]) * 0.001
85
+ rank_zero_info(f"Average Total memory: {total_memory:.2f} MB")
86
+
87
+ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
88
+ self._start_time = time.time()
89
+
90
+ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
91
+ if not trainer.loggers:
92
+ raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
93
+
94
+ device = trainer.strategy.root_device
95
+ memory_info = xm.get_memory_info(device)
96
+ epoch_time = time.time() - self._start_time
97
+
98
+ free_memory = memory_info["kb_free"]
99
+ peak_memory = memory_info["kb_total"] - free_memory
100
+
101
+ free_memory = trainer.strategy.reduce(free_memory) * 0.001
102
+ peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
103
+ epoch_time = trainer.strategy.reduce(epoch_time)
104
+
105
+ for logger in trainer.loggers:
106
+ logger.log_metrics(
107
+ {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
108
+ step=trainer.current_epoch,
109
+ )
110
+
111
+ if self._verbose:
112
+ rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
113
+ rank_zero_info(f"Average Peak memory: {peak_memory:.2f} MB")
114
+ rank_zero_info(f"Average Free memory: {free_memory:.2f} MB")
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py ADDED
@@ -0,0 +1,264 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """LightningDataModule for loading DataLoaders with ease."""
15
+ from argparse import ArgumentParser, Namespace
16
+ from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
17
+
18
+ from torch.utils.data import DataLoader, Dataset, IterableDataset
19
+
20
+ from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
21
+ from pytorch_lightning.core.mixins import HyperparametersMixin
22
+ from pytorch_lightning.utilities import rank_zero_deprecation
23
+ from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
24
+
25
+
26
+ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
27
+ """A DataModule standardizes the training, val, test splits, data preparation and transforms. The main
28
+ advantage is consistent data splits, data preparation and transforms across models.
29
+
30
+ Example::
31
+
32
+ class MyDataModule(LightningDataModule):
33
+ def __init__(self):
34
+ super().__init__()
35
+ def prepare_data(self):
36
+ # download, split, etc...
37
+ # only called on 1 GPU/TPU in distributed
38
+ def setup(self, stage):
39
+ # make assignments here (val/train/test split)
40
+ # called on every process in DDP
41
+ def train_dataloader(self):
42
+ train_split = Dataset(...)
43
+ return DataLoader(train_split)
44
+ def val_dataloader(self):
45
+ val_split = Dataset(...)
46
+ return DataLoader(val_split)
47
+ def test_dataloader(self):
48
+ test_split = Dataset(...)
49
+ return DataLoader(test_split)
50
+ def teardown(self):
51
+ # clean up after fit or test
52
+ # called on every process in DDP
53
+ """
54
+
55
+ name: str = ...
56
+
57
+ def __init__(self, train_transforms=None, val_transforms=None, test_transforms=None, dims=None):
58
+ super().__init__()
59
+ if train_transforms is not None:
60
+ rank_zero_deprecation(
61
+ "DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7."
62
+ )
63
+ if val_transforms is not None:
64
+ rank_zero_deprecation(
65
+ "DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7."
66
+ )
67
+ if test_transforms is not None:
68
+ rank_zero_deprecation(
69
+ "DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7."
70
+ )
71
+ if dims is not None:
72
+ rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")
73
+ self._train_transforms = train_transforms
74
+ self._val_transforms = val_transforms
75
+ self._test_transforms = test_transforms
76
+ self._dims = dims if dims is not None else ()
77
+
78
+ # Pointer to the trainer object
79
+ self.trainer = None
80
+
81
+ @property
82
+ def train_transforms(self):
83
+ """Optional transforms (or collection of transforms) you can apply to train dataset.
84
+
85
+ .. deprecated:: v1.5 Will be removed in v1.7.0.
86
+ """
87
+
88
+ rank_zero_deprecation(
89
+ "DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7."
90
+ )
91
+ return self._train_transforms
92
+
93
+ @train_transforms.setter
94
+ def train_transforms(self, t):
95
+ rank_zero_deprecation(
96
+ "DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7."
97
+ )
98
+ self._train_transforms = t
99
+
100
+ @property
101
+ def val_transforms(self):
102
+ """Optional transforms (or collection of transforms) you can apply to validation dataset.
103
+
104
+ .. deprecated:: v1.5 Will be removed in v1.7.0.
105
+ """
106
+
107
+ rank_zero_deprecation(
108
+ "DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7."
109
+ )
110
+ return self._val_transforms
111
+
112
+ @val_transforms.setter
113
+ def val_transforms(self, t):
114
+ rank_zero_deprecation(
115
+ "DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7."
116
+ )
117
+ self._val_transforms = t
118
+
119
+ @property
120
+ def test_transforms(self):
121
+ """Optional transforms (or collection of transforms) you can apply to test dataset.
122
+
123
+ .. deprecated:: v1.5 Will be removed in v1.7.0.
124
+ """
125
+
126
+ rank_zero_deprecation(
127
+ "DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7."
128
+ )
129
+ return self._test_transforms
130
+
131
+ @test_transforms.setter
132
+ def test_transforms(self, t):
133
+ rank_zero_deprecation(
134
+ "DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7."
135
+ )
136
+ self._test_transforms = t
137
+
138
+ @property
139
+ def dims(self):
140
+ """A tuple describing the shape of your data. Extra functionality exposed in ``size``.
141
+
142
+ .. deprecated:: v1.5 Will be removed in v1.7.0.
143
+ """
144
+ rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")
145
+ return self._dims
146
+
147
+ @dims.setter
148
+ def dims(self, d):
149
+ rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")
150
+ self._dims = d
151
+
152
+ def size(self, dim=None) -> Union[Tuple, List[Tuple]]:
153
+ """Return the dimension of each input either as a tuple or list of tuples. You can index this just as you
154
+ would with a torch tensor.
155
+
156
+ .. deprecated:: v1.5 Will be removed in v1.7.0.
157
+ """
158
+ rank_zero_deprecation("DataModule property `size` was deprecated in v1.5 and will be removed in v1.7.")
159
+
160
+ if dim is not None:
161
+ return self.dims[dim]
162
+
163
+ return self.dims
164
+
165
+ @classmethod
166
+ def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
167
+ """Extends existing argparse by default `LightningDataModule` attributes."""
168
+ return add_argparse_args(cls, parent_parser, **kwargs)
169
+
170
+ @classmethod
171
+ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
172
+ """Create an instance from CLI arguments.
173
+
174
+ Args:
175
+ args: The parser or namespace to take arguments from. Only known arguments will be
176
+ parsed and passed to the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
177
+ **kwargs: Additional keyword arguments that may override ones in the parser or namespace.
178
+ These must be valid DataModule arguments.
179
+
180
+ Example::
181
+
182
+ parser = ArgumentParser(add_help=False)
183
+ parser = LightningDataModule.add_argparse_args(parser)
184
+ module = LightningDataModule.from_argparse_args(args)
185
+ """
186
+ return from_argparse_args(cls, args, **kwargs)
187
+
188
+ @classmethod
189
+ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
190
+ r"""Scans the DataModule signature and returns argument names, types and default values.
191
+
192
+ Returns:
193
+ List with tuples of 3 values:
194
+ (argument name, set with argument types, argument default value).
195
+ """
196
+ return get_init_arguments_and_types(cls)
197
+
198
+ @classmethod
199
+ def from_datasets(
200
+ cls,
201
+ train_dataset: Optional[Union[Dataset, Sequence[Dataset], Mapping[str, Dataset]]] = None,
202
+ val_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
203
+ test_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
204
+ batch_size: int = 1,
205
+ num_workers: int = 0,
206
+ ):
207
+ r"""
208
+ Create an instance from torch.utils.data.Dataset.
209
+
210
+ Args:
211
+ train_dataset: (optional) Dataset to be used for train_dataloader()
212
+ val_dataset: (optional) Dataset or list of Dataset to be used for val_dataloader()
213
+ test_dataset: (optional) Dataset or list of Dataset to be used for test_dataloader()
214
+ batch_size: Batch size to use for each dataloader. Default is 1.
215
+ num_workers: Number of subprocesses to use for data loading. 0 means that the
216
+ data will be loaded in the main process. A common choice is the number of CPUs available.
217
+
218
+ """
219
+
220
+ def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader:
221
+ shuffle &= not isinstance(ds, IterableDataset)
222
+ return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
223
+
224
+ def train_dataloader():
225
+ if isinstance(train_dataset, Mapping):
226
+ return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
227
+ if isinstance(train_dataset, Sequence):
228
+ return [dataloader(ds, shuffle=True) for ds in train_dataset]
229
+ return dataloader(train_dataset, shuffle=True)
230
+
231
+ def val_dataloader():
232
+ if isinstance(val_dataset, Sequence):
233
+ return [dataloader(ds) for ds in val_dataset]
234
+ return dataloader(val_dataset)
235
+
236
+ def test_dataloader():
237
+ if isinstance(test_dataset, Sequence):
238
+ return [dataloader(ds) for ds in test_dataset]
239
+ return dataloader(test_dataset)
240
+
241
+ datamodule = cls()
242
+ if train_dataset is not None:
243
+ datamodule.train_dataloader = train_dataloader
244
+ if val_dataset is not None:
245
+ datamodule.val_dataloader = val_dataloader
246
+ if test_dataset is not None:
247
+ datamodule.test_dataloader = test_dataloader
248
+ return datamodule
249
+
250
+ def state_dict(self) -> Dict[str, Any]:
251
+ """Called when saving a checkpoint, implement to generate and save datamodule state.
252
+
253
+ Returns:
254
+ A dictionary containing datamodule state.
255
+ """
256
+ return {}
257
+
258
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
259
+ """Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict.
260
+
261
+ Args:
262
+ state_dict: the datamodule state returned by ``state_dict``.
263
+ """
264
+ pass
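A minimal sketch of ``from_datasets`` (defined above) with toy tensors; the shapes and batch size are arbitrary placeholders:

    import torch
    from torch.utils.data import TensorDataset
    from pytorch_lightning import LightningDataModule

    train_ds = TensorDataset(torch.randn(64, 3), torch.randint(0, 2, (64,)))
    val_ds = TensorDataset(torch.randn(16, 3), torch.randint(0, 2, (16,)))

    dm = LightningDataModule.from_datasets(
        train_dataset=train_ds, val_dataset=val_ds, batch_size=8, num_workers=0
    )
    # dm.train_dataloader() and dm.val_dataloader() now return ready-made DataLoaders.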
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/decorators.py ADDED
@@ -0,0 +1,60 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
15
+
16
+ rank_zero_deprecation(
17
+ "Using `pytorch_lightning.core.decorators.parameter_validation` is deprecated in v1.5, "
18
+ "and will be removed in v1.7. It has been replaced by automatic parameters tying with "
19
+ "`pytorch_lightning.utilities.params_tying.set_shared_parameters`"
20
+ )
21
+
22
+ from functools import wraps # noqa: E402
23
+ from typing import Callable # noqa: E402
24
+
25
+
26
+ def parameter_validation(fn: Callable) -> Callable:
27
+ """Validates that the module parameter lengths match after moving to the device. It is useful when tying
28
+ weights on TPUs.
29
+
30
+ Args:
31
+ fn: ``model_to_device`` method
32
+
33
+ Note:
34
+ TPUs require weights to be tied/shared after moving the module to the device.
35
+ Failure to do this results in the initialization of new weights which are not tied.
36
+ To overcome this issue, weights should be tied using the ``on_post_move_to_device`` model hook
37
+ which is called after the module has been moved to the device.
38
+
39
+ See Also:
40
+ - `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
41
+ """
42
+
43
+ @wraps(fn)
44
+ def inner_fn(self, *args, **kwargs):
45
+ pre_layer_count = len(list(self.model.parameters()))
46
+ module = fn(self, *args, **kwargs)
47
+ self.model.on_post_move_to_device()
48
+ post_layer_count = len(list(self.model.parameters()))
49
+
50
+ if not pre_layer_count == post_layer_count:
51
+ rank_zero_warn(
52
+ "The model layers do not match after moving to the target device."
53
+ " If your model employs weight sharing on TPU,"
54
+ " please tie your weights using the `on_post_move_to_device` model hook.\n"
55
+ f"Layer count: [Before: {pre_layer_count} After: {post_layer_count}]"
56
+ )
57
+
58
+ return module
59
+
60
+ return inner_fn
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/hooks.py ADDED
@@ -0,0 +1,828 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Various hooks to be used in the Lightning code."""
15
+
16
+ from typing import Any, Dict, List, Optional
17
+
18
+ import torch
19
+ from torch.optim.optimizer import Optimizer
20
+
21
+ from pytorch_lightning.utilities import move_data_to_device
22
+ from pytorch_lightning.utilities.exceptions import MisconfigurationException
23
+ from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
24
+
25
+
26
+ class ModelHooks:
27
+ """Hooks to be used in LightningModule."""
28
+
29
+ def on_fit_start(self) -> None:
30
+ """Called at the very beginning of fit.
31
+
32
+ If on DDP it is called on every process
33
+ """
34
+
35
+ def on_fit_end(self) -> None:
36
+ """Called at the very end of fit.
37
+
38
+ If on DDP it is called on every process
39
+ """
40
+
41
+ def on_train_start(self) -> None:
42
+ """Called at the beginning of training after sanity check."""
43
+
44
+ def on_train_end(self) -> None:
45
+ """Called at the end of training before logger experiment is closed."""
46
+
47
+ def on_validation_start(self) -> None:
48
+ """Called at the beginning of validation."""
49
+
50
+ def on_validation_end(self) -> None:
51
+ """Called at the end of validation."""
52
+
53
+ def on_test_start(self) -> None:
54
+ """Called at the beginning of testing."""
55
+
56
+ def on_test_end(self) -> None:
57
+ """Called at the end of testing."""
58
+
59
+ def on_predict_start(self) -> None:
60
+ """Called at the beginning of predicting."""
61
+
62
+ def on_predict_end(self) -> None:
63
+ """Called at the end of predicting."""
64
+
65
+ def on_pretrain_routine_start(self) -> None:
66
+ """Called at the beginning of the pretrain routine (between fit and train start).
67
+
68
+ - fit
69
+ - pretrain_routine start
70
+ - pretrain_routine end
71
+ - training_start
72
+
73
+ .. deprecated:: v1.6
74
+ :meth:`on_pretrain_routine_start` has been deprecated in v1.6 and will be removed in v1.8.
75
+ Use ``on_fit_start`` instead.
76
+ """
77
+
78
+ def on_pretrain_routine_end(self) -> None:
79
+ """Called at the end of the pretrain routine (between fit and train start).
80
+
81
+ - fit
82
+ - pretrain_routine start
83
+ - pretrain_routine end
84
+ - training_start
85
+
86
+ .. deprecated:: v1.6
87
+ :meth:`on_pretrain_routine_end` has been deprecated in v1.6 and will be removed in v1.8.
88
+ Use ``on_fit_start`` instead.
89
+ """
90
+
91
+ def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> Optional[int]:
92
+ """Called in the training loop before anything happens for that batch.
93
+
94
+ If you return -1 here, you will skip training for the rest of the current epoch.
95
+
96
+ Args:
97
+ batch: The batched data as it is returned by the training DataLoader.
98
+ batch_idx: the index of the batch
99
+ unused: Deprecated argument. Will be removed in v1.7.
100
+ """
101
+
102
+ def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, unused: int = 0) -> None:
103
+ """Called in the training loop after the batch.
104
+
105
+ Args:
106
+ outputs: The outputs of training_step_end(training_step(x))
107
+ batch: The batched data as it is returned by the training DataLoader.
108
+ batch_idx: the index of the batch
109
+ unused: Deprecated argument. Will be removed in v1.7.
110
+ """
111
+
112
+ def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
113
+ """Called in the validation loop before anything happens for that batch.
114
+
115
+ Args:
116
+ batch: The batched data as it is returned by the validation DataLoader.
117
+ batch_idx: the index of the batch
118
+ dataloader_idx: the index of the dataloader
119
+ """
120
+
121
+ def on_validation_batch_end(
122
+ self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
123
+ ) -> None:
124
+ """Called in the validation loop after the batch.
125
+
126
+ Args:
127
+ outputs: The outputs of validation_step_end(validation_step(x))
128
+ batch: The batched data as it is returned by the validation DataLoader.
129
+ batch_idx: the index of the batch
130
+ dataloader_idx: the index of the dataloader
131
+ """
132
+
133
+ def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
134
+ """Called in the test loop before anything happens for that batch.
135
+
136
+ Args:
137
+ batch: The batched data as it is returned by the test DataLoader.
138
+ batch_idx: the index of the batch
139
+ dataloader_idx: the index of the dataloader
140
+ """
141
+
142
+ def on_test_batch_end(
143
+ self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
144
+ ) -> None:
145
+ """Called in the test loop after the batch.
146
+
147
+ Args:
148
+ outputs: The outputs of test_step_end(test_step(x))
149
+ batch: The batched data as it is returned by the test DataLoader.
150
+ batch_idx: the index of the batch
151
+ dataloader_idx: the index of the dataloader
152
+ """
153
+
154
+ def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
155
+ """Called in the predict loop before anything happens for that batch.
156
+
157
+ Args:
158
+ batch: The batched data as it is returned by the prediction DataLoader.
159
+ batch_idx: the index of the batch
160
+ dataloader_idx: the index of the dataloader
161
+ """
162
+
163
+ def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None:
164
+ """Called in the predict loop after the batch.
165
+
166
+ Args:
167
+ outputs: The outputs of predict_step_end(predict_step(x))
168
+ batch: The batched data as it is returned by the test DataLoader.
169
+ batch_idx: the index of the batch
170
+ dataloader_idx: the index of the dataloader
171
+ """
172
+
173
+ def on_validation_model_eval(self) -> None:
174
+ """Sets the model to eval during the val loop."""
175
+ self.trainer.model.eval()
176
+
177
+ def on_validation_model_train(self) -> None:
178
+ """Sets the model to train during the val loop."""
179
+ self.trainer.model.train()
180
+
181
+ def on_test_model_train(self) -> None:
182
+ """Sets the model to train during the test loop."""
183
+ self.trainer.model.train()
184
+
185
+ def on_test_model_eval(self) -> None:
186
+ """Sets the model to eval during the test loop."""
187
+ self.trainer.model.eval()
188
+
189
+ def on_predict_model_eval(self) -> None:
190
+ """Sets the model to eval during the predict loop."""
191
+ self.trainer.model.eval()
192
+
193
+ def on_epoch_start(self) -> None:
194
+ """Called when either of train/val/test epoch begins.
195
+
196
+ .. deprecated:: v1.6
197
+ :meth:`on_epoch_start` has been deprecated in v1.6 and will be removed in v1.8.
198
+ Use ``on_<train/validation/test>_epoch_start`` instead.
199
+ """
200
+
201
+ def on_epoch_end(self) -> None:
202
+ """Called when either of train/val/test epoch ends.
203
+
204
+ .. deprecated:: v1.6
205
+ :meth:`on_epoch_end` has been deprecated in v1.6 and will be removed in v1.8.
206
+ Use ``on_<train/validation/test>_epoch_end`` instead.
207
+ """
208
+
209
+ def on_train_epoch_start(self) -> None:
210
+ """Called in the training loop at the very beginning of the epoch."""
211
+
212
+ def on_train_epoch_end(self) -> None:
213
+ """Called in the training loop at the very end of the epoch.
214
+
215
+ To access all batch outputs at the end of the epoch, either:
216
+
217
+ 1. Implement `training_epoch_end` in the LightningModule OR
218
+ 2. Cache data across steps on the attribute(s) of the `LightningModule` and access them in this hook
219
+ """
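A minimal sketch of the caching pattern described above, assuming a user-defined ``LitModel`` whose ``training_step`` returns a loss (``compute_loss`` is a hypothetical helper):

    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.training_step_losses = []  # cache batch outputs across steps

        def training_step(self, batch, batch_idx):
            loss = self.compute_loss(batch)  # hypothetical loss helper
            self.training_step_losses.append(loss.detach())
            return loss

        def on_train_epoch_end(self):
            # every cached batch output for the epoch is available here
            epoch_mean = torch.stack(self.training_step_losses).mean()
            self.log("train_loss_epoch", epoch_mean)
            self.training_step_losses.clear()  # free memory for the next epoch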
220
+
221
+ def on_validation_epoch_start(self) -> None:
222
+ """Called in the validation loop at the very beginning of the epoch."""
223
+
224
+ def on_validation_epoch_end(self) -> None:
225
+ """Called in the validation loop at the very end of the epoch."""
226
+
227
+ def on_test_epoch_start(self) -> None:
228
+ """Called in the test loop at the very beginning of the epoch."""
229
+
230
+ def on_test_epoch_end(self) -> None:
231
+ """Called in the test loop at the very end of the epoch."""
232
+
233
+ def on_predict_epoch_start(self) -> None:
234
+ """Called at the beginning of predicting."""
235
+
236
+ def on_predict_epoch_end(self, results: List[Any]) -> None:
237
+ """Called at the end of predicting."""
238
+
239
+ def on_before_zero_grad(self, optimizer: Optimizer) -> None:
240
+ """Called after ``training_step()`` and before ``optimizer.zero_grad()``.
241
+
242
+ Called in the training loop after taking an optimizer step and before zeroing grads.
243
+ A good place to inspect weight information after the weights have been updated.
244
+
245
+ This is where it is called::
246
+
247
+ for optimizer in optimizers:
248
+ out = training_step(...)
249
+
250
+ model.on_before_zero_grad(optimizer) # < ---- called here
251
+ optimizer.zero_grad()
252
+
253
+ backward()
254
+
255
+ Args:
256
+ optimizer: The optimizer for which grads should be zeroed.
257
+ """
258
+
259
+ def on_before_backward(self, loss: torch.Tensor) -> None:
260
+ """Called before ``loss.backward()``.
261
+
262
+ Args:
263
+ loss: Loss divided by number of batches for gradient accumulation and scaled if using native AMP.
264
+ """
265
+ pass
266
+
267
+ def on_after_backward(self) -> None:
268
+ """Called after ``loss.backward()`` and before optimizers are stepped.
269
+
270
+ Note:
271
+ If using native AMP, the gradients will not be unscaled at this point.
272
+ Use the ``on_before_optimizer_step`` if you need the unscaled gradients.
273
+ """
274
+
275
+ def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
276
+ """Called before ``optimizer.step()``.
277
+
278
+ If using gradient accumulation, the hook is called once the gradients have been accumulated.
279
+ See: :paramref:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches`.
280
+
281
+ If using native AMP, the loss will be unscaled before calling this hook.
282
+ See these `docs <https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients>`__
283
+ for more information on the scaling of gradients.
284
+
285
+ If clipping gradients, the gradients will not have been clipped yet.
286
+
287
+ Args:
288
+ optimizer: Current optimizer being used.
289
+ optimizer_idx: Index of the current optimizer being used.
290
+
291
+ Example::
292
+
293
+ def on_before_optimizer_step(self, optimizer, optimizer_idx):
294
+ # example to inspect gradient information in tensorboard
295
+ if self.trainer.global_step % 25 == 0: # don't make the tf file huge
296
+ for k, v in self.named_parameters():
297
+ self.logger.experiment.add_histogram(
298
+ tag=k, values=v.grad, global_step=self.trainer.global_step
299
+ )
300
+ """
301
+
302
+ def on_post_move_to_device(self) -> None:
303
+ """Called in the ``parameter_validation`` decorator after
304
+ :meth:`~pytorch_lightning.core.LightningModule.to` is called. This is a good place to tie weights between
305
+ modules after moving them to a device. Can be used when training models with weight sharing properties on
306
+ TPU.
307
+
308
+ Addresses the handling of shared weights on TPU:
309
+ https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks
310
+
311
+ Example::
312
+
313
+ def on_post_move_to_device(self):
314
+ self.decoder.weight = self.encoder.weight
315
+ """
316
+
317
+ def configure_sharded_model(self) -> None:
318
+ """Hook to create modules in a distributed aware context. This is useful for when using sharded plugins,
319
+ where we'd like to shard the model immediately, which for extremely large models can save
320
+ memory and initialization time.
321
+
322
+ This hook is called during each of fit/val/test/predict stages in the same process, so ensure that
323
+ implementation of this hook is idempotent.
324
+ """
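A minimal sketch of the idempotent pattern this hook asks for (the attribute name and layer sizes are placeholders):

    class LitModel(LightningModule):
        def configure_sharded_model(self):
            # guard so repeated calls across fit/validate/test/predict stay idempotent
            if not hasattr(self, "block"):
                self.block = torch.nn.Sequential(
                    torch.nn.Linear(32, 32),
                    torch.nn.ReLU(),
                )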
325
+
326
+
327
+ class DataHooks:
328
+ """Hooks to be used for data related stuff."""
329
+
330
+ def __init__(self) -> None:
331
+ """
332
+ Attributes:
333
+ prepare_data_per_node:
334
+ If True, each LOCAL_RANK=0 will call prepare_data.
335
+ Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
336
+ allow_zero_length_dataloader_with_multiple_devices:
337
+ If True, dataloader with zero length within local rank is allowed.
338
+ Default value is False.
339
+ """
340
+ super().__init__()
341
+ self.prepare_data_per_node: bool = True
342
+ self.allow_zero_length_dataloader_with_multiple_devices: bool = False
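Purely illustrative: these attributes can be overridden in a subclass ``__init__`` (the class name is a placeholder):

    class MyDataModule(LightningDataModule):
        def __init__(self):
            super().__init__()
            # keep the default: LOCAL_RANK=0 of every node downloads the data
            self.prepare_data_per_node = True
            self.allow_zero_length_dataloader_with_multiple_devices = False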
343
+
344
+ def prepare_data(self) -> None:
345
+ """Use this to download and prepare data. Downloading and saving data with multiple processes (distributed
346
+ settings) will result in corrupted data. Lightning ensures this method is called only within a single
347
+ process, so you can safely add your downloading logic within.
348
+
349
+ .. warning:: DO NOT set state on the model (use ``setup`` instead)
350
+ since this is NOT called on every device
351
+
352
+ Example::
353
+
354
+ def prepare_data(self):
355
+ # good
356
+ download_data()
357
+ tokenize()
358
+ etc()
359
+
360
+ # bad
361
+ self.split = data_split
362
+ self.some_state = some_other_state()
363
+
364
+ In DDP ``prepare_data`` can be called in two ways (using Trainer(prepare_data_per_node)):
365
+
366
+ 1. Once per node. This is the default and is only called on LOCAL_RANK=0.
367
+ 2. Once in total. Only called on GLOBAL_RANK=0.
368
+
369
+ See :ref:`prepare_data_per_node<common/lightning_module:prepare_data_per_node>`.
370
+
371
+ Example::
372
+
373
+ # DEFAULT
374
+ # called once per node on LOCAL_RANK=0 of that node
375
+ Trainer(prepare_data_per_node=True)
376
+
377
+ # call on GLOBAL_RANK=0 (great for shared file systems)
378
+ Trainer(prepare_data_per_node=False)
379
+
380
+ This is called before requesting the dataloaders:
381
+
382
+ .. code-block:: python
383
+
384
+ model.prepare_data()
385
+ initialize_distributed()
386
+ model.setup(stage)
387
+ model.train_dataloader()
388
+ model.val_dataloader()
389
+ model.test_dataloader()
390
+ """
391
+
392
+ def setup(self, stage: Optional[str] = None) -> None:
393
+ """Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when
394
+ you need to build models dynamically or adjust something about them. This hook is called on every process
395
+ when using DDP.
396
+
397
+ Args:
398
+ stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
399
+
400
+ Example::
401
+
402
+ class LitModel(...):
403
+ def __init__(self):
404
+ self.l1 = None
405
+
406
+ def prepare_data(self):
407
+ download_data()
408
+ tokenize()
409
+
410
+ # don't do this
411
+ self.something = some_other_state()
412
+
413
+ def setup(self, stage):
414
+ data = load_data(...)
415
+ self.l1 = nn.Linear(28, data.num_classes)
416
+ """
417
+
418
+ def teardown(self, stage: Optional[str] = None) -> None:
419
+ """Called at the end of fit (train + validate), validate, test, or predict.
420
+
421
+ Args:
422
+ stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
423
+ """
424
+
425
+ def train_dataloader(self) -> TRAIN_DATALOADERS:
426
+ """Implement one or more PyTorch DataLoaders for training.
427
+
428
+ Return:
429
+ A collection of :class:`torch.utils.data.DataLoader` specifying training samples.
430
+ In the case of multiple dataloaders, please see this :ref:`section <multiple-dataloaders>`.
431
+
432
+ The dataloader you return will not be reloaded unless you set
433
+ :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to
434
+ a positive integer.
435
+
436
+ For data processing use the following pattern:
437
+
438
+ - download in :meth:`prepare_data`
439
+ - process and split in :meth:`setup`
440
+
441
+ However, the above are only necessary for distributed processing.
442
+
443
+ .. warning:: do not assign state in prepare_data
444
+
445
+ - :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`
446
+ - :meth:`prepare_data`
447
+ - :meth:`setup`
448
+
449
+ Note:
450
+ Lightning adds the correct sampler for distributed and arbitrary hardware.
451
+ There is no need to set it yourself.
452
+
453
+ Example::
454
+
455
+ # single dataloader
456
+ def train_dataloader(self):
457
+ transform = transforms.Compose([transforms.ToTensor(),
458
+ transforms.Normalize((0.5,), (1.0,))])
459
+ dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
460
+ download=True)
461
+ loader = torch.utils.data.DataLoader(
462
+ dataset=dataset,
463
+ batch_size=self.batch_size,
464
+ shuffle=True
465
+ )
466
+ return loader
467
+
468
+ # multiple dataloaders, return as list
469
+ def train_dataloader(self):
470
+ mnist = MNIST(...)
471
+ cifar = CIFAR(...)
472
+ mnist_loader = torch.utils.data.DataLoader(
473
+ dataset=mnist, batch_size=self.batch_size, shuffle=True
474
+ )
475
+ cifar_loader = torch.utils.data.DataLoader(
476
+ dataset=cifar, batch_size=self.batch_size, shuffle=True
477
+ )
478
+ # each batch will be a list of tensors: [batch_mnist, batch_cifar]
479
+ return [mnist_loader, cifar_loader]
480
+
481
+ # multiple dataloaders, return as dict
482
+ def train_dataloader(self):
483
+ mnist = MNIST(...)
484
+ cifar = CIFAR(...)
485
+ mnist_loader = torch.utils.data.DataLoader(
486
+ dataset=mnist, batch_size=self.batch_size, shuffle=True
487
+ )
488
+ cifar_loader = torch.utils.data.DataLoader(
489
+ dataset=cifar, batch_size=self.batch_size, shuffle=True
490
+ )
491
+ # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
492
+ return {'mnist': mnist_loader, 'cifar': cifar_loader}
493
+ """
494
+ raise MisconfigurationException("`train_dataloader` must be implemented to be used with the Lightning Trainer")
495
+
496
+ def test_dataloader(self) -> EVAL_DATALOADERS:
497
+ r"""
498
+ Implement one or multiple PyTorch DataLoaders for testing.
499
+
500
+ For data processing use the following pattern:
501
+
502
+ - download in :meth:`prepare_data`
503
+ - process and split in :meth:`setup`
504
+
505
+ However, the above are only necessary for distributed processing.
506
+
507
+ .. warning:: do not assign state in prepare_data
508
+
509
+
510
+ - :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`
511
+ - :meth:`prepare_data`
512
+ - :meth:`setup`
513
+
514
+ Note:
515
+ Lightning adds the correct sampler for distributed and arbitrary hardware.
516
+ There is no need to set it yourself.
517
+
518
+ Return:
519
+ A :class:`torch.utils.data.DataLoader` or a sequence of them specifying testing samples.
520
+
521
+ Example::
522
+
523
+ def test_dataloader(self):
524
+ transform = transforms.Compose([transforms.ToTensor(),
525
+ transforms.Normalize((0.5,), (1.0,))])
526
+ dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
527
+ download=True)
528
+ loader = torch.utils.data.DataLoader(
529
+ dataset=dataset,
530
+ batch_size=self.batch_size,
531
+ shuffle=False
532
+ )
533
+
534
+ return loader
535
+
536
+ # can also return multiple dataloaders
537
+ def test_dataloader(self):
538
+ return [loader_a, loader_b, ..., loader_n]
539
+
540
+ Note:
541
+ If you don't need a test dataset and a :meth:`test_step`, you don't need to implement
542
+ this method.
543
+
544
+ Note:
545
+ In the case where you return multiple test dataloaders, the :meth:`test_step`
546
+ will have an argument ``dataloader_idx`` which matches the order here.
547
+ """
548
+ raise MisconfigurationException("`test_dataloader` must be implemented to be used with the Lightning Trainer")
549
+
550
+ def val_dataloader(self) -> EVAL_DATALOADERS:
551
+ r"""
552
+ Implement one or multiple PyTorch DataLoaders for validation.
553
+
554
+ The dataloader you return will not be reloaded unless you set
555
+ :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to
556
+ a positive integer.
557
+
558
+ It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
559
+
560
+ - :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`
561
+ - :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`
562
+ - :meth:`prepare_data`
563
+ - :meth:`setup`
564
+
565
+ Note:
566
+ Lightning adds the correct sampler for distributed and arbitrary hardware
567
+ There is no need to set it yourself.
568
+
569
+ Return:
570
+ A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.
571
+
572
+ Examples::
573
+
574
+ def val_dataloader(self):
575
+ transform = transforms.Compose([transforms.ToTensor(),
576
+ transforms.Normalize((0.5,), (1.0,))])
577
+ dataset = MNIST(root='/path/to/mnist/', train=False,
578
+ transform=transform, download=True)
579
+ loader = torch.utils.data.DataLoader(
580
+ dataset=dataset,
581
+ batch_size=self.batch_size,
582
+ shuffle=False
583
+ )
584
+
585
+ return loader
586
+
587
+ # can also return multiple dataloaders
588
+ def val_dataloader(self):
589
+ return [loader_a, loader_b, ..., loader_n]
590
+
591
+ Note:
592
+ If you don't need a validation dataset and a :meth:`validation_step`, you don't need to
593
+ implement this method.
594
+
595
+ Note:
596
+ In the case where you return multiple validation dataloaders, the :meth:`validation_step`
597
+ will have an argument ``dataloader_idx`` which matches the order here.
598
+ """
599
+ raise MisconfigurationException("`val_dataloader` must be implemented to be used with the Lightning Trainer")
600
+
601
+ def predict_dataloader(self) -> EVAL_DATALOADERS:
602
+ r"""
603
+ Implement one or multiple PyTorch DataLoaders for prediction.
604
+
605
+ It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.
606
+
607
+ - :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`
608
+ - :meth:`prepare_data`
609
+ - :meth:`setup`
610
+
611
+ Note:
612
+ Lightning adds the correct sampler for distributed and arbitrary hardware
613
+ There is no need to set it yourself.
614
+
615
+ Return:
616
+ A :class:`torch.utils.data.DataLoader` or a sequence of them specifying prediction samples.
617
+
618
+ Note:
619
+ In the case where you return multiple prediction dataloaders, the :meth:`predict_step`
620
+ will have an argument ``dataloader_idx`` which matches the order here.
621
+ """
622
+ raise MisconfigurationException(
623
+ "`predict_dataloader` must be implemented to be used with the Lightning Trainer"
624
+ )
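The other dataloader hooks carry usage examples; for symmetry, a minimal sketch of this one (paths and batch size are placeholders):

    def predict_dataloader(self):
        transform = transforms.ToTensor()
        dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, download=True)
        return torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=False)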
625
+
626
+ def on_train_dataloader(self) -> None:
627
+ """Called before requesting the train dataloader.
628
+
629
+ .. deprecated:: v1.5
630
+ :meth:`on_train_dataloader` is deprecated and will be removed in v1.7.0.
631
+ Please use :meth:`train_dataloader()` directly.
632
+ """
633
+
634
+ def on_val_dataloader(self) -> None:
635
+ """Called before requesting the val dataloader.
636
+
637
+ .. deprecated:: v1.5
638
+ :meth:`on_val_dataloader` is deprecated and will be removed in v1.7.0.
639
+ Please use :meth:`val_dataloader()` directly.
640
+ """
641
+
642
+ def on_test_dataloader(self) -> None:
643
+ """Called before requesting the test dataloader.
644
+
645
+ .. deprecated:: v1.5
646
+ :meth:`on_test_dataloader` is deprecated and will be removed in v1.7.0.
647
+ Please use :meth:`test_dataloader()` directly.
648
+ """
649
+
650
+ def on_predict_dataloader(self) -> None:
651
+ """Called before requesting the predict dataloader.
652
+
653
+ .. deprecated:: v1.5
654
+ :meth:`on_predict_dataloader` is deprecated and will be removed in v1.7.0.
655
+ Please use :meth:`predict_dataloader()` directly.
656
+ """
657
+
658
+ def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
659
+ """Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors wrapped in a custom
660
+ data structure.
661
+
662
+ The data types listed below (and any arbitrary nesting of them) are supported out of the box:
663
+
664
+ - :class:`torch.Tensor` or anything that implements `.to(...)`
665
+ - :class:`list`
666
+ - :class:`dict`
667
+ - :class:`tuple`
668
+ - :class:`torchtext.data.batch.Batch`
669
+
670
+ For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
671
+
672
+ Note:
673
+ This hook should only transfer the data and not modify it, nor should it move the data to
674
+ any other device than the one passed in as argument (unless you know what you are doing).
675
+ To check the current state of execution of this hook you can use
676
+ ``self.trainer.training/testing/validating/predicting`` so that you can
677
+ add different logic as per your requirement.
678
+
679
+ Note:
680
+ This hook only runs on single GPU training and DDP (no data-parallel).
681
+ Data-Parallel support will come in the near future.
682
+
683
+ Args:
684
+ batch: A batch of data that needs to be transferred to a new device.
685
+ device: The target device as defined in PyTorch.
686
+ dataloader_idx: The index of the dataloader to which the batch belongs.
687
+
688
+ Returns:
689
+ A reference to the data on the new device.
690
+
691
+ Example::
692
+
693
+ def transfer_batch_to_device(self, batch, device, dataloader_idx):
694
+ if isinstance(batch, CustomBatch):
695
+ # move all tensors in your custom data structure to the device
696
+ batch.samples = batch.samples.to(device)
697
+ batch.targets = batch.targets.to(device)
698
+ elif dataloader_idx == 0:
699
+ # skip device transfer for the first dataloader or anything you wish
700
+ pass
701
+ else:
702
+ batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
703
+ return batch
704
+
705
+ Raises:
706
+ MisconfigurationException:
707
+ If using data-parallel, ``Trainer(strategy='dp')``.
708
+
709
+ See Also:
710
+ - :meth:`move_data_to_device`
711
+ - :meth:`apply_to_collection`
712
+ """
713
+ return move_data_to_device(batch, device)
714
+
715
+ def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
716
+ """Override to alter or apply batch augmentations to your batch before it is transferred to the device.
717
+
718
+ Note:
719
+ To check the current state of execution of this hook you can use
720
+ ``self.trainer.training/testing/validating/predicting`` so that you can
721
+ add different logic as per your requirement.
722
+
723
+ Note:
724
+ This hook only runs on single GPU training and DDP (no data-parallel).
725
+ Data-Parallel support will come in the near future.
726
+
727
+ Args:
728
+ batch: A batch of data that needs to be altered or augmented.
729
+ dataloader_idx: The index of the dataloader to which the batch belongs.
730
+
731
+ Returns:
732
+ A batch of data
733
+
734
+ Example::
735
+
736
+ def on_before_batch_transfer(self, batch, dataloader_idx):
737
+ batch['x'] = transforms(batch['x'])
738
+ return batch
739
+
740
+ Raises:
741
+ MisconfigurationException:
742
+ If using data-parallel, ``Trainer(strategy='dp')``.
743
+
744
+ See Also:
745
+ - :meth:`on_after_batch_transfer`
746
+ - :meth:`transfer_batch_to_device`
747
+ """
748
+ return batch
749
+
750
+ def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
751
+ """Override to alter or apply batch augmentations to your batch after it is transferred to the device.
752
+
753
+ Note:
754
+ To check the current state of execution of this hook you can use
755
+ ``self.trainer.training/testing/validating/predicting`` so that you can
756
+ add different logic as per your requirement.
757
+
758
+ Note:
759
+ This hook only runs on single GPU training and DDP (no data-parallel).
760
+ Data-Parallel support will come in the near future.
761
+
762
+ Args:
763
+ batch: A batch of data that needs to be altered or augmented.
764
+ dataloader_idx: The index of the dataloader to which the batch belongs.
765
+
766
+ Returns:
767
+ A batch of data
768
+
769
+ Example::
770
+
771
+ def on_after_batch_transfer(self, batch, dataloader_idx):
772
+ batch['x'] = gpu_transforms(batch['x'])
773
+ return batch
774
+
775
+ Raises:
776
+ MisconfigurationException:
777
+ If using data-parallel, ``Trainer(strategy='dp')``.
778
+
779
+ See Also:
780
+ - :meth:`on_before_batch_transfer`
781
+ - :meth:`transfer_batch_to_device`
782
+ """
783
+ return batch
784
+
785
+
786
+ class CheckpointHooks:
787
+ """Hooks to be used with Checkpointing."""
788
+
789
+ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
790
+ r"""
791
+ Called by Lightning to restore your model.
792
+ If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this.
793
+
794
+ Args:
795
+ checkpoint: Loaded checkpoint
796
+
797
+ Example::
798
+
799
+ def on_load_checkpoint(self, checkpoint):
800
+ # 99% of the time you don't need to implement this method
801
+ self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
802
+
803
+ Note:
804
+ Lightning auto-restores global step, epoch, and train state including amp scaling.
805
+ There is no need for you to restore anything regarding training.
806
+ """
807
+
808
+ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
809
+ r"""
810
+ Called by Lightning when saving a checkpoint to give you a chance to store anything
811
+ else you might want to save.
812
+
813
+ Args:
814
+ checkpoint: The full checkpoint dictionary before it gets dumped to a file.
815
+ Implementations of this hook can insert additional data into this dictionary.
816
+
817
+ Example::
818
+
819
+ def on_save_checkpoint(self, checkpoint):
820
+ # 99% of use cases you don't need to implement this method
821
+ checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
822
+
823
+ Note:
824
+ Lightning saves all aspects of training (epoch, global step, etc...)
825
+ including amp scaling.
826
+ There is no need for you to store anything about training.
827
+
828
+ """
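A minimal sketch pairing the two checkpoint hooks (the attribute name is illustrative):

    class LitModel(LightningModule):
        def on_save_checkpoint(self, checkpoint):
            # stash extra picklable state next to what Lightning already saves
            checkpoint["my_extra_state"] = self.my_extra_state

        def on_load_checkpoint(self, checkpoint):
            # restore the custom state when the checkpoint is loaded
            self.my_extra_state = checkpoint["my_extra_state"]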
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py ADDED
@@ -0,0 +1,409 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from contextlib import contextmanager
15
+ from dataclasses import fields
16
+ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
17
+ from weakref import proxy
18
+
19
+ import torch
20
+ from torch import optim
21
+ from torch.optim import Optimizer
22
+
23
+ import pytorch_lightning as pl
24
+ from pytorch_lightning.utilities.exceptions import MisconfigurationException
25
+ from pytorch_lightning.utilities.model_helpers import is_overridden
26
+ from pytorch_lightning.utilities.rank_zero import rank_zero_warn
27
+ from pytorch_lightning.utilities.types import _Stateful, LRSchedulerConfig, LRSchedulerTypeTuple, ReduceLROnPlateau
28
+
29
+
30
+ def do_nothing_closure() -> None:
31
+ return
32
+
33
+
34
+ class LightningOptimizer:
35
+ """This class wraps the user's optimizers and properly handles the backward and optimizer_step logic
36
+ across accelerators, AMP, and accumulate_grad_batches."""
37
+
38
+ def __init__(self, optimizer: Optimizer):
39
+ # copy most of the `Optimizer` methods into this instance. `__del__` is skipped in case the optimizer has
40
+ # implemented custom logic which we would not want to call on destruction of the `LightningOptimizer`
41
+ self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")}
42
+
43
+ # For Horovod
44
+ if hasattr(optimizer, "skip_synchronize"):
45
+ self.__class__ = type(
46
+ "Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__.__bases__[0]), {}
47
+ )
48
+ self.skip_synchronize = optimizer.skip_synchronize
49
+ self.synchronize = optimizer.synchronize
50
+ else:
51
+ self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
52
+
53
+ self._optimizer = optimizer
54
+ self._strategy: Optional[pl.strategies.Strategy] = None
55
+ self._optimizer_idx = 0
56
+ # to inject logic around the optimizer step, particularly useful with manual optimization
57
+ self._on_before_step = do_nothing_closure
58
+ self._on_after_step = do_nothing_closure
59
+
60
+ @property
61
+ def optimizer(self) -> Optimizer:
62
+ return self._optimizer
63
+
64
+ @classmethod
65
+ def _to_lightning_optimizer(
66
+ cls, optimizer: Union[Optimizer, "LightningOptimizer"], strategy: "pl.strategies.Strategy", opt_idx: int
67
+ ) -> "LightningOptimizer":
68
+ if isinstance(optimizer, LightningOptimizer):
69
+ # the user could return a `LightningOptimizer` from `configure_optimizers`, see test:
70
+ # tests/core/test_lightning_optimizer.py::test_lightning_optimizer[False]
71
+ lightning_optimizer = optimizer
72
+ else:
73
+ lightning_optimizer = cls(optimizer)
74
+ lightning_optimizer._strategy = proxy(strategy)
75
+ lightning_optimizer._optimizer_idx = opt_idx
76
+ return lightning_optimizer
77
+
78
+ @contextmanager
79
+ def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]:
80
+ """This function is just a helper for advanced users.
81
+
82
+ Consider the current optimizer as A and all other optimizers as B.
83
+ Toggling means that all parameters from B exclusive to A will have ``requires_grad`` set to False.
84
+
85
+ When performing gradient accumulation, there is no need to perform grad synchronization
86
+ during the accumulation phase.
87
+ Setting `sync_grad` to False will block this synchronization and improve performance.
88
+ """
89
+ # local import here to avoid circular import
90
+ from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior
91
+
92
+ assert self._strategy is not None
93
+ lightning_module = self._strategy.lightning_module
94
+ assert lightning_module is not None
95
+ with _block_parallel_sync_behavior(self._strategy, block=(not sync_grad)):
96
+ lightning_module.toggle_optimizer(self, self._optimizer_idx)
97
+ yield
98
+ lightning_module.untoggle_optimizer(self._optimizer_idx)
99
+
100
+ def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any:
101
+ """Performs a single optimization step (parameter update).
102
+
103
+ Args:
104
+ closure: An optional optimizer closure.
105
+ kwargs: Any additional arguments to the ``optimizer.step()`` call.
106
+
107
+ Returns:
108
+ The output from the step call, which is generally the output of the closure execution.
109
+
110
+ Example::
111
+
112
+ # Scenario for a GAN using manual optimization
113
+ def training_step(...):
114
+ opt_gen, opt_dis = self.optimizers()
115
+
116
+ ...
117
+
118
+ # compute generator loss
119
+ loss_gen = self.compute_generator_loss(...)
120
+ # zero_grad needs to be called before backward
121
+ opt_gen.zero_grad()
122
+ self.manual_backward(loss_gen)
123
+ opt_gen.step()
124
+
125
+ # compute discriminator loss
126
+ loss_dis = self.compute_discriminator_loss(...)
127
+
128
+ # zero_grad needs to be called before backward
129
+ opt_dis.zero_grad()
130
+ self.manual_backward(loss_dis)
131
+ opt_dis.step()
132
+
133
+
134
+ # A more advanced example
135
+ def training_step(self, batch, batch_idx, ...):
136
+ opt_gen, opt_dis = self.optimizers()
137
+
138
+ ...
139
+ accumulated_grad_batches = batch_idx % 2 == 0
140
+
141
+ # compute generator loss
142
+ def closure_gen():
143
+ loss_gen = self.compute_generator_loss(...)
144
+ self.manual_backward(loss_gen)
145
+ if accumulated_grad_batches:
146
+ opt_gen.zero_grad()
147
+
148
+ with opt_gen.toggle_model(sync_grad=accumulated_grad_batches):
149
+ opt_gen.step(closure=closure_gen)
150
+
151
+ def closure_dis():
152
+ loss_dis = self.compute_discriminator_loss(...)
153
+ self.manual_backward(loss_dis)
154
+ if accumulated_grad_batches:
155
+ opt_dis.zero_grad()
156
+
157
+ with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
158
+ opt_dis.step(closure=closure_dis)
159
+ """
160
+ self._on_before_step()
161
+
162
+ if closure is None:
163
+ closure = do_nothing_closure
164
+ elif not callable(closure):
165
+ raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
166
+
167
+ assert self._strategy is not None
168
+ step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
169
+
170
+ self._on_after_step()
171
+
172
+ return step_output
173
+
174
+
175
+ def _init_optimizers_and_lr_schedulers(
176
+ model: "pl.LightningModule",
177
+ ) -> Tuple[List[Optimizer], List[LRSchedulerConfig], List[int]]:
178
+ """Calls `LightningModule.configure_optimizers` and parses and validates the output."""
179
+ assert model.trainer is not None
180
+ optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)
181
+
182
+ if optim_conf is None:
183
+ rank_zero_warn(
184
+ "`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
185
+ )
186
+ optim_conf = _MockOptimizer()
187
+
188
+ optimizers, lr_schedulers, optimizer_frequencies, monitor = _configure_optimizers(optim_conf)
189
+ lr_scheduler_configs = (
190
+ _configure_schedulers_automatic_opt(lr_schedulers, monitor)
191
+ if model.automatic_optimization
192
+ else _configure_schedulers_manual_opt(lr_schedulers)
193
+ )
194
+ _set_scheduler_opt_idx(optimizers, lr_scheduler_configs)
195
+ _validate_scheduler_api(lr_scheduler_configs, model)
196
+ return optimizers, lr_scheduler_configs, optimizer_frequencies
197
+
198
+
199
+ def _configure_optimizers(
200
+ optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple]
201
+ ) -> Tuple[List, List, List, Optional[str]]:
202
+ optimizers, lr_schedulers, optimizer_frequencies = [], [], []
203
+ monitor = None
204
+
205
+ # single output, single optimizer
206
+ if isinstance(optim_conf, Optimizer):
207
+ optimizers = [optim_conf]
208
+ # two lists, optimizer + lr schedulers
209
+ elif (
210
+ isinstance(optim_conf, (list, tuple))
211
+ and len(optim_conf) == 2
212
+ and isinstance(optim_conf[0], list)
213
+ and all(isinstance(opt, Optimizer) for opt in optim_conf[0])
214
+ ):
215
+ opt, sch = optim_conf
216
+ optimizers = opt
217
+ lr_schedulers = sch if isinstance(sch, list) else [sch]
218
+ # single dictionary
219
+ elif isinstance(optim_conf, dict):
220
+ _validate_optim_conf(optim_conf)
221
+ optimizers = [optim_conf["optimizer"]]
222
+ monitor = optim_conf.get("monitor", None)
223
+ lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
224
+ # multiple dictionaries
225
+ elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
226
+ for opt_dict in optim_conf:
227
+ _validate_optim_conf(opt_dict)
228
+ optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
229
+ scheduler_dict = (
230
+ lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx)
231
+ if isinstance(scheduler, dict)
232
+ else {"scheduler": scheduler, "opt_idx": opt_idx}
233
+ )
234
+
235
+ lr_schedulers = [
236
+ scheduler_dict(opt_dict["lr_scheduler"], opt_idx)
237
+ for opt_idx, opt_dict in enumerate(optim_conf)
238
+ if "lr_scheduler" in opt_dict
239
+ ]
240
+ optimizer_frequencies = [
241
+ opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
242
+ ]
243
+ # assert that if frequencies are present, they are given for all optimizers
244
+ if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
245
+ raise ValueError("A frequency must be given to each optimizer.")
246
+ # single list or tuple, multiple optimizer
247
+ elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf):
248
+ optimizers = list(optim_conf)
249
+ # unknown configuration
250
+ else:
251
+ raise MisconfigurationException(
252
+ "Unknown configuration for model optimizers."
253
+ " Output from `model.configure_optimizers()` should be one of:\n"
254
+ " * `Optimizer`\n"
255
+ " * [`Optimizer`]\n"
256
+ " * ([`Optimizer`], [`_LRScheduler`])\n"
257
+ ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `_LRScheduler`}\n'
258
+ ' * A list of the previously described dict format, with an optional "frequency" key (int)'
259
+ )
260
+ return optimizers, lr_schedulers, optimizer_frequencies, monitor
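For reference, a hedged sketch of two alternative ``configure_optimizers`` return shapes that this parser accepts (the optimizers, learning rates, and monitored metric are placeholders):

    # shape 1: a single optimizer
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    # shape 2: a dict with an "lr_scheduler" entry; ReduceLROnPlateau requires a monitor
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
        }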
261
+
262
+
263
+ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]:
264
+ """Convert each scheduler into `LRSchedulerConfig` with relevant information, when using automatic
265
+ optimization."""
266
+ lr_scheduler_configs = []
267
+ for scheduler in schedulers:
268
+ if isinstance(scheduler, dict):
269
+ # check provided keys
270
+ supported_keys = {field.name for field in fields(LRSchedulerConfig)}
271
+ extra_keys = scheduler.keys() - supported_keys
272
+ if extra_keys:
273
+ rank_zero_warn(
274
+ f"Found unsupported keys in the lr scheduler dict: {extra_keys}."
275
+ " HINT: remove them from the output of `configure_optimizers`.",
276
+ category=RuntimeWarning,
277
+ )
278
+ scheduler = {k: v for k, v in scheduler.items() if k in supported_keys}
279
+ if "scheduler" not in scheduler:
280
+ raise MisconfigurationException(
281
+ 'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
282
+ )
283
+ if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"):
284
+ raise MisconfigurationException(
285
+ 'The "interval" key in lr scheduler dict must be "step" or "epoch"'
286
+ f' but is "{scheduler["interval"]}"'
287
+ )
288
+ scheduler["reduce_on_plateau"] = isinstance(scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau)
289
+ if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None:
290
+ raise MisconfigurationException(
291
+ "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
292
+ ' For example: {"optimizer": optimizer, "lr_scheduler":'
293
+ ' {"scheduler": scheduler, "monitor": "your_loss"}}'
294
+ )
295
+ is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR)
296
+ if is_one_cycle and scheduler.get("interval", "epoch") == "epoch":
297
+ rank_zero_warn(
298
+ "A `OneCycleLR` scheduler is using 'interval': 'epoch'."
299
+ " Are you sure you didn't mean 'interval': 'step'?",
300
+ category=RuntimeWarning,
301
+ )
302
+ config = LRSchedulerConfig(**scheduler)
303
+ elif isinstance(scheduler, ReduceLROnPlateau):
304
+ if monitor is None:
305
+ raise MisconfigurationException(
306
+ "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
307
+ " scheduler is used. For example:"
308
+ ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
309
+ )
310
+ config = LRSchedulerConfig(scheduler, reduce_on_plateau=True, monitor=monitor)
311
+ else:
312
+ config = LRSchedulerConfig(scheduler)
313
+ lr_scheduler_configs.append(config)
314
+ return lr_scheduler_configs
315
+
316
+
317
+ def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig]:
318
+ """Convert each scheduler into `LRSchedulerConfig` structure with relevant information, when using manual
319
+ optimization."""
320
+ lr_scheduler_configs = []
321
+ for scheduler in schedulers:
322
+ if isinstance(scheduler, dict):
323
+ invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"}
324
+ keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]
325
+
326
+ if keys_to_warn:
327
+ rank_zero_warn(
328
+ f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
329
+ " You need to call `lr_scheduler.step()` manually in manual optimization.",
330
+ category=RuntimeWarning,
331
+ )
332
+
333
+ config = LRSchedulerConfig(**{key: scheduler[key] for key in scheduler if key not in invalid_keys})
334
+ else:
335
+ config = LRSchedulerConfig(scheduler)
336
+ lr_scheduler_configs.append(config)
337
+ return lr_scheduler_configs
338
+
339
+
340
+ def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model: "pl.LightningModule") -> None:
341
+ for config in lr_scheduler_configs:
342
+ scheduler = config.scheduler
343
+ if not isinstance(scheduler, _Stateful):
344
+ raise TypeError(
345
+ f"The provided lr scheduler `{scheduler.__class__.__name__}` is invalid."
346
+ " It should have `state_dict` and `load_state_dict` methods defined."
347
+ )
348
+
349
+ if not isinstance(scheduler, LRSchedulerTypeTuple) and not is_overridden("lr_scheduler_step", model):
350
+ raise MisconfigurationException(
351
+ f"The provided lr scheduler `{scheduler.__class__.__name__}` doesn't follow PyTorch's LRScheduler"
352
+ " API. You should override the `LightningModule.lr_scheduler_step` hook with your own logic if"
353
+ " you are using a custom LR scheduler."
354
+ )
355
+
356
+
357
+ def _set_scheduler_opt_idx(optimizers: List[Optimizer], lr_scheduler_configs: List[LRSchedulerConfig]) -> None:
358
+ for config in lr_scheduler_configs:
359
+
360
+ for opt_idx, opt in enumerate(optimizers):
361
+ if config.scheduler.optimizer is opt:
362
+ if config.opt_idx is not None and config.opt_idx != opt_idx:
363
+ raise MisconfigurationException(
364
+ "`opt_idx` set inside scheduler config does not match with the index"
365
+ " of the respective optimizer returned from `configure_optimizers`."
366
+ )
367
+
368
+ config.opt_idx = opt_idx
369
+ break
370
+ else:
371
+ raise MisconfigurationException(
372
+ "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
373
+ )
374
+
375
+
376
+ def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
377
+ valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"}
378
+ extra_keys = optim_conf.keys() - valid_keys
379
+ if extra_keys:
380
+ rank_zero_warn(
381
+ f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning
382
+ )
383
+
384
+
385
+ class _MockOptimizer(Optimizer):
386
+ """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None` is returned from
387
+ `configure_optimizers`."""
388
+
389
+ def __init__(self) -> None:
390
+ super().__init__([torch.zeros(1)], {})
391
+
392
+ def add_param_group(self, param_group: Dict[Any, Any]) -> None:
393
+ pass # Do Nothing
394
+
395
+ def load_state_dict(self, state_dict: Dict[Any, Any]) -> None:
396
+ pass # Do Nothing
397
+
398
+ def state_dict(self) -> Dict[str, Any]:
399
+ return {} # Return Empty
400
+
401
+ def step(self, closure: Callable = None) -> None:
402
+ if closure is not None:
403
+ closure()
404
+
405
+ def zero_grad(self, set_to_none: Optional[bool] = False) -> None:
406
+ pass # Do Nothing
407
+
408
+ def __repr__(self) -> str:
409
+ return "No Optimizer"
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/core/saving.py ADDED
@@ -0,0 +1,419 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import ast
16
+ import csv
17
+ import inspect
18
+ import logging
19
+ import os
20
+ from argparse import Namespace
21
+ from copy import deepcopy
22
+ from enum import Enum
23
+ from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union
24
+ from warnings import warn
25
+
26
+ import torch
27
+ import yaml
28
+
29
+ from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict
30
+ from pytorch_lightning.utilities.apply_func import apply_to_collection
31
+ from pytorch_lightning.utilities.cloud_io import get_filesystem
32
+ from pytorch_lightning.utilities.cloud_io import load as pl_load
33
+ from pytorch_lightning.utilities.migration import pl_legacy_patch
34
+ from pytorch_lightning.utilities.parsing import parse_class_init_keys
35
+ from pytorch_lightning.utilities.rank_zero import rank_zero_warn
36
+
37
+ log = logging.getLogger(__name__)
38
+ PRIMITIVE_TYPES = (bool, int, float, str)
39
+ ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
40
+
41
+ if _OMEGACONF_AVAILABLE:
42
+ from omegaconf import OmegaConf
43
+ from omegaconf.dictconfig import DictConfig
44
+ from omegaconf.errors import UnsupportedValueType, ValidationError
45
+
46
+ # the older keys shall be listed first
47
+ CHECKPOINT_PAST_HPARAMS_KEYS = ("hparams", "module_arguments") # used in 0.7.6
48
+
49
+
50
+ class ModelIO:
51
+ CHECKPOINT_HYPER_PARAMS_KEY = "hyper_parameters"
52
+ CHECKPOINT_HYPER_PARAMS_NAME = "hparams_name"
53
+ CHECKPOINT_HYPER_PARAMS_TYPE = "hparams_type"
54
+
55
+ @classmethod
56
+ def load_from_checkpoint(
57
+ cls,
58
+ checkpoint_path: Union[str, IO],
59
+ map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
60
+ hparams_file: Optional[str] = None,
61
+ strict: bool = True,
62
+ **kwargs,
63
+ ):
64
+ r"""
65
+ Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
66
+ it stores the arguments passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``.
67
+
68
+ Any arguments specified through \*\*kwargs will override args stored in ``"hyper_parameters"``.
69
+
70
+ Args:
71
+ checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object
72
+ map_location:
73
+ If your checkpoint saved a GPU model and you now load on CPUs
74
+ or a different number of GPUs, use this to map to the new setup.
75
+ The behaviour is the same as in :func:`torch.load`.
76
+ hparams_file: Optional path to a .yaml file with hierarchical structure
77
+ as in this example::
78
+
79
+ drop_prob: 0.2
80
+ dataloader:
81
+ batch_size: 32
82
+
83
+ You most likely won't need this since Lightning will always save the hyperparameters
84
+ to the checkpoint.
85
+ However, if your checkpoint weights don't have the hyperparameters saved,
86
+ use this method to pass in a .yaml file with the hparams you'd like to use.
87
+ These will be converted into a :class:`~dict` and passed into your
88
+ :class:`LightningModule` for use.
89
+
90
+ If your model's ``hparams`` argument is :class:`~argparse.Namespace`
91
+ and the .yaml file has a hierarchical structure, you need to refactor your model to treat
92
+ ``hparams`` as :class:`~dict`.
93
+ strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys
94
+ returned by this module's state dict.
95
+ kwargs: Any extra keyword args needed to init the model. Can also be used to override saved
96
+ hyperparameter values.
97
+
98
+ Return:
99
+ :class:`LightningModule` instance with loaded weights and hyperparameters (if available).
100
+
101
+ Note:
102
+ ``load_from_checkpoint`` is a **class** method. You should use your :class:`LightningModule`
103
+ **class** to call it instead of the :class:`LightningModule` instance.
104
+
105
+ Example::
106
+
107
+ # load weights without mapping ...
108
+ model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
109
+
110
+ # or load weights mapping all weights from GPU 1 to GPU 0 ...
111
+ map_location = {'cuda:1':'cuda:0'}
112
+ model = MyLightningModule.load_from_checkpoint(
113
+ 'path/to/checkpoint.ckpt',
114
+ map_location=map_location
115
+ )
116
+
117
+ # or load weights and hyperparameters from separate files.
118
+ model = MyLightningModule.load_from_checkpoint(
119
+ 'path/to/checkpoint.ckpt',
120
+ hparams_file='/path/to/hparams_file.yaml'
121
+ )
122
+
123
+ # override some of the params with new values
124
+ model = MyLightningModule.load_from_checkpoint(
125
+ PATH,
126
+ num_layers=128,
127
+ pretrained_ckpt_path=NEW_PATH,
128
+ )
129
+
130
+ # predict
131
+ pretrained_model.eval()
132
+ pretrained_model.freeze()
133
+ y_hat = pretrained_model(x)
134
+ """
135
+ with pl_legacy_patch():
136
+ if map_location is not None:
137
+ checkpoint = pl_load(checkpoint_path, map_location=map_location)
138
+ else:
139
+ checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
140
+
141
+ if hparams_file is not None:
142
+ extension = hparams_file.split(".")[-1]
143
+ if extension.lower() == "csv":
144
+ hparams = load_hparams_from_tags_csv(hparams_file)
145
+ elif extension.lower() in ("yml", "yaml"):
146
+ hparams = load_hparams_from_yaml(hparams_file)
147
+ else:
148
+ raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")
149
+
150
+ hparams["on_gpu"] = False
151
+
152
+ # overwrite hparams by the given file
153
+ checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams
154
+
155
+ # for past checkpoint need to add the new key
156
+ if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
157
+ checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
158
+ # override the hparams with values that were passed in
159
+ checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
160
+
161
+ model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
162
+ return model
163
+
164
+ @classmethod
165
+ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cls_kwargs_new):
166
+ cls_spec = inspect.getfullargspec(cls.__init__)
167
+ cls_init_args_name = inspect.signature(cls.__init__).parameters.keys()
168
+
169
+ self_var, args_var, kwargs_var = parse_class_init_keys(cls)
170
+ drop_names = [n for n in (self_var, args_var, kwargs_var) if n]
171
+ cls_init_args_name = list(filter(lambda n: n not in drop_names, cls_init_args_name))
172
+
173
+ cls_kwargs_loaded = {}
174
+ # pass in the values we saved automatically
175
+ if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
176
+
177
+ # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
178
+ for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
179
+ cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))
180
+
181
+ # 2. Try to restore model hparams from checkpoint using the new key
182
+ _new_hparam_key = cls.CHECKPOINT_HYPER_PARAMS_KEY
183
+ cls_kwargs_loaded.update(checkpoint.get(_new_hparam_key))
184
+
185
+ # 3. Ensure that `cls_kwargs_loaded` has the right type, for backward compatibility between dict and Namespace
186
+ cls_kwargs_loaded = _convert_loaded_hparams(
187
+ cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE)
188
+ )
189
+
190
+ # 4. Update cls_kwargs_new with cls_kwargs_loaded, such that new has higher priority
191
+ args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
192
+ if args_name and args_name in cls_init_args_name:
193
+ cls_kwargs_loaded = {args_name: cls_kwargs_loaded}
194
+
195
+ _cls_kwargs = {}
196
+ _cls_kwargs.update(cls_kwargs_loaded)
197
+ _cls_kwargs.update(cls_kwargs_new)
198
+
199
+ if not cls_spec.varkw:
200
+ # filter kwargs according to class init unless it allows any argument via kwargs
201
+ _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}
202
+
203
+ model = cls(**_cls_kwargs)
204
+
205
+ # give model a chance to load something
206
+ model.on_load_checkpoint(checkpoint)
207
+
208
+ # load the state_dict on the model automatically
209
+ keys = model.load_state_dict(checkpoint["state_dict"], strict=strict)
210
+
211
+ if not strict:
212
+ if keys.missing_keys:
213
+ rank_zero_warn(
214
+ f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
215
+ )
216
+ if keys.unexpected_keys:
217
+ rank_zero_warn(
218
+ f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
219
+ )
220
+
221
+ return model
222
+
223
+ # -------------------------
224
+ # OPTIONAL HOOKS
225
+ # -------------------------
226
+ def on_hpc_save(self, checkpoint: Dict[str, Any]) -> None:
227
+ """Hook to do whatever you need right before Slurm manager saves the model.
228
+
229
+ Args:
230
+ checkpoint: A dictionary in which you can save variables to save in a checkpoint.
231
+ Contents need to be pickleable.
232
+
233
+ .. deprecated:: v1.6
234
+ This method is deprecated in v1.6 and will be removed in v1.8.
235
+ Please use ``LightningModule.on_save_checkpoint`` instead.
236
+ """
237
+
238
+ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:
239
+ """Hook to do whatever you need right before Slurm manager loads the model.
240
+
241
+ Args:
242
+ checkpoint: A dictionary with variables from the checkpoint.
243
+
244
+ .. deprecated:: v1.6
245
+ This method is deprecated in v1.6 and will be removed in v1.8.
246
+ Please use ``LightningModule.on_load_checkpoint`` instead.
247
+ """
248
+
249
+
250
+ def _convert_loaded_hparams(model_args: dict, hparams_type: Optional[Union[Callable, str]] = None) -> object:
251
+ """Convert hparams according to the given type, passed as a callable or (legacy) string."""
252
+ # if no hparams type was defined
253
+ if not hparams_type:
254
+ return model_args
255
+ # if past checkpoint loaded, convert str to callable
256
+ if isinstance(hparams_type, str):
257
+ hparams_type = AttributeDict
258
+ # convert hparams
259
+ return hparams_type(model_args)
260
+
261
+
262
+ def update_hparams(hparams: dict, updates: dict) -> None:
263
+ """Overrides hparams with new values.
264
+
265
+ >>> hparams = {'c': 4}
266
+ >>> update_hparams(hparams, {'a': {'b': 2}, 'c': 1})
267
+ >>> hparams['a']['b'], hparams['c']
268
+ (2, 1)
269
+ >>> update_hparams(hparams, {'a': {'b': 4}, 'c': 7})
270
+ >>> hparams['a']['b'], hparams['c']
271
+ (4, 7)
272
+
273
+ Args:
274
+ hparams: the original params and also target object
275
+ updates: new params to be used as update
276
+ """
277
+ for k, v in updates.items():
278
+ # if missing, add the key
279
+ if k not in hparams:
280
+ hparams[k] = v
281
+ continue
282
+
283
+ # recurse if dictionary
284
+ if isinstance(v, dict):
285
+ update_hparams(hparams[k], updates[k])
286
+ else:
287
+ # update the value
288
+ hparams.update({k: v})
289
+
290
+
291
+ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
292
+     """Load hparams from a file.
+ 
+     >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
+     >>> path_csv = os.path.join('.', 'testing-hparams.csv')
+     >>> save_hparams_to_tags_csv(path_csv, hparams)
+     >>> hparams_new = load_hparams_from_tags_csv(path_csv)
+     >>> vars(hparams) == hparams_new
+     True
+     >>> os.remove(path_csv)
+     """
+     fs = get_filesystem(tags_csv)
+     if not fs.exists(tags_csv):
+         rank_zero_warn(f"Missing Tags: {tags_csv}.", category=RuntimeWarning)
+         return {}
+ 
+     with fs.open(tags_csv, "r", newline="") as fp:
+         csv_reader = csv.reader(fp, delimiter=",")
+         tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
+ 
+     return tags
+ 
+ 
+ def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
+     fs = get_filesystem(tags_csv)
+     if not fs.isdir(os.path.dirname(tags_csv)):
+         raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")
+ 
+     if isinstance(hparams, Namespace):
+         hparams = vars(hparams)
+ 
+     with fs.open(tags_csv, "w", newline="") as fp:
+         fieldnames = ["key", "value"]
+         writer = csv.DictWriter(fp, fieldnames=fieldnames)
+         writer.writerow({"key": "key", "value": "value"})
+         for k, v in hparams.items():
+             writer.writerow({"key": k, "value": v})
+ 
+ 
+ def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict[str, Any]:
+     """Load hparams from a file.
+ 
+     Args:
+         config_yaml: Path to config yaml file
+         use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
+             the hparams will be converted to ``DictConfig`` if possible.
+ 
+     >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
+     >>> path_yaml = './testing-hparams.yaml'
+     >>> save_hparams_to_yaml(path_yaml, hparams)
+     >>> hparams_new = load_hparams_from_yaml(path_yaml)
+     >>> vars(hparams) == hparams_new
+     True
+     >>> os.remove(path_yaml)
+     """
+     fs = get_filesystem(config_yaml)
+     if not fs.exists(config_yaml):
+         rank_zero_warn(f"Missing Tags: {config_yaml}.", category=RuntimeWarning)
+         return {}
+ 
+     with fs.open(config_yaml, "r") as fp:
+         hparams = yaml.full_load(fp)
+ 
+     if _OMEGACONF_AVAILABLE:
+         if use_omegaconf:
+             try:
+                 return OmegaConf.create(hparams)
+             except (UnsupportedValueType, ValidationError):
+                 pass
+     return hparams
+ 
+ 
+ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace], use_omegaconf: bool = True) -> None:
+     """
+     Args:
+         config_yaml: path to new YAML file
+         hparams: parameters to be saved
+         use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
+             the hparams will be converted to ``DictConfig`` if possible.
+ 
+     """
+     fs = get_filesystem(config_yaml)
+     if not fs.isdir(os.path.dirname(config_yaml)):
+         raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")
+ 
+     # convert Namespace or AD to dict
+     if isinstance(hparams, Namespace):
+         hparams = vars(hparams)
+     elif isinstance(hparams, AttributeDict):
+         hparams = dict(hparams)
+ 
+     # saving with OmegaConf objects
+     if _OMEGACONF_AVAILABLE and use_omegaconf:
+         # deepcopy: hparams from user shouldn't be resolved
+         hparams = deepcopy(hparams)
+         hparams = apply_to_collection(hparams, DictConfig, OmegaConf.to_container, resolve=True)
+         with fs.open(config_yaml, "w", encoding="utf-8") as fp:
+             try:
+                 OmegaConf.save(hparams, fp)
+                 return
+             except (UnsupportedValueType, ValidationError):
+                 pass
+ 
+     if not isinstance(hparams, dict):
+         raise TypeError("hparams must be dictionary")
+ 
+     hparams_allowed = {}
+     # drop parameters which contain some strange datatypes as fsspec
+     for k, v in hparams.items():
+         try:
+             v = v.name if isinstance(v, Enum) else v
+             yaml.dump(v)
+         except TypeError:
+             warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
+             hparams[k] = type(v).__name__
+         else:
+             hparams_allowed[k] = v
+ 
+     # saving the standard way
+     with fs.open(config_yaml, "w", newline="") as fp:
+         yaml.dump(hparams_allowed, fp)
+ 
+ 
+ def convert(val: str) -> Union[int, float, bool, str]:
+     try:
+         return ast.literal_eval(val)
+     except (ValueError, SyntaxError) as err:
+         log.debug(err)
+         return val
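For reference, a small sketch (not part of the commit) of how the `convert` helper above treats raw CSV strings, assuming the package is importable and this hunk belongs to `pytorch_lightning.core.saving`: anything `ast.literal_eval` can parse becomes a Python literal, everything else comes back unchanged.

    from pytorch_lightning.core.saving import convert  # import path assumed from the file this hunk belongs to

    assert convert("32") == 32                    # parsed as an int literal
    assert convert("0.001") == 0.001              # parsed as a float literal
    assert convert("True") is True                # parsed as a bool literal
    assert convert("./any/path/here") == "./any/path/here"  # not a literal, returned as-is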
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/distributed/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # Copyright The PyTorch Lightning team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from pytorch_lightning.distributed.dist import LightningDistributed  # noqa: F401
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/distributed/dist.py ADDED
@@ -0,0 +1,47 @@
+ # Copyright The PyTorch Lightning team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import Any
+ 
+ import torch.distributed
+ 
+ from pytorch_lightning.utilities import rank_zero_deprecation
+ from pytorch_lightning.utilities.distributed import group as _group
+ 
+ 
+ class LightningDistributed:
+     """
+     .. deprecated:: v1.5
+         This class is deprecated in v1.5 and will be removed in v1.7.
+         The broadcast logic will be moved to the :class:`DDPStrategy` and :class`DDPSpawnStrategy` classes.
+ 
+     """
+ 
+     def __init__(self, rank=None, device=None):
+         rank_zero_deprecation(
+             "LightningDistributed is deprecated in v1.5 and will be removed in v1.7."
+             "Broadcast logic is implemented directly in the :class:`Strategy` implementations."
+         )
+         self.rank = rank
+         self.device = device
+ 
+     def broadcast(self, obj: Any, group=_group.WORLD):
+         # always wrap into a list so it can be broadcasted.
+         obj = [obj]
+ 
+         if self.rank != 0:
+             obj = [None] * len(obj)
+ 
+         torch.distributed.broadcast_object_list(obj, 0, group=group or _group.WORLD)
+ 
+         return obj[0]
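A minimal usage sketch for the deprecated broadcast helper above, assuming `torch.distributed.init_process_group` has already been called on every rank; the object is wrapped in a single-element list so `broadcast_object_list` can fill it in place on non-zero ranks.

    import torch.distributed
    from pytorch_lightning.distributed import LightningDistributed

    # assumes an initialized process group (e.g. init_process_group("gloo", ...)) on every rank
    broadcaster = LightningDistributed(rank=torch.distributed.get_rank(), device="cpu")
    payload = {"best_score": 0.92} if broadcaster.rank == 0 else None
    payload = broadcaster.broadcast(payload)  # every rank now holds rank 0's dict (emits a deprecation warning)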
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (3.51 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/__pycache__/layer_sync.cpython-38.pyc ADDED
Binary file (3.19 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__init__.py ADDED
@@ -0,0 +1,20 @@
+ # Copyright The PyTorch Lightning team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from pytorch_lightning.plugins.environments.bagua_environment import BaguaEnvironment  # noqa: F401
+ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment  # noqa: F401
+ from pytorch_lightning.plugins.environments.kubeflow_environment import KubeflowEnvironment  # noqa: F401
+ from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment  # noqa: F401
+ from pytorch_lightning.plugins.environments.lsf_environment import LSFEnvironment  # noqa: F401
+ from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment  # noqa: F401
+ from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment  # noqa: F401
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (897 Bytes)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/__pycache__/bagua_environment.cpython-38.pyc ADDED
Binary file (2.78 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/bagua_environment.py ADDED
@@ -0,0 +1,62 @@
+ # Copyright The PyTorch Lightning team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ 
+ import logging
+ import os
+ 
+ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
+ 
+ log = logging.getLogger(__name__)
+ 
+ 
+ class BaguaEnvironment(ClusterEnvironment):
+     """Environment for distributed training with `Bagua <https://tutorials.baguasys.com/>`_"""
+ 
+     @property
+     def creates_processes_externally(self) -> bool:
+         return True
+ 
+     @property
+     def main_address(self) -> str:
+         return os.environ.get("MASTER_ADDR", "127.0.0.1")
+ 
+     @property
+     def main_port(self) -> int:
+         return int(os.environ.get("MASTER_PORT", -1))
+ 
+     @property
+     def service_port(self) -> int:
+         return int(os.environ.get("BAGUA_SERVICE_PORT", -1))
+ 
+     @staticmethod
+     def detect() -> bool:
+         return "BAGUA_SERVICE_PORT" in os.environ
+ 
+     def world_size(self) -> int:
+         return int(os.environ["WORLD_SIZE"])
+ 
+     def set_world_size(self, size: int) -> None:
+         log.debug("`BaguaEnvironment.set_world_size` was called, but setting world size is not allowed. Ignored.")
+ 
+     def global_rank(self) -> int:
+         return int(os.environ["RANK"])
+ 
+     def set_global_rank(self, rank: int) -> None:
+         log.debug("`BaguaEnvironment.set_global_rank` was called, but setting global rank is not allowed. Ignored.")
+ 
+     def local_rank(self) -> int:
+         return int(os.environ.get("LOCAL_RANK", 0))
+ 
+     def node_rank(self) -> int:
+         return int(os.environ.get("NODE_RANK", 0))
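As the `detect()` method above shows, Bagua is recognized purely from the `BAGUA_SERVICE_PORT` variable exported by the Bagua launcher; a tiny sketch with an illustrative port value:

    import os
    from pytorch_lightning.plugins.environments import BaguaEnvironment

    os.environ["BAGUA_SERVICE_PORT"] = "29500"  # illustrative value; normally set by the launcher
    assert BaguaEnvironment.detect()
    assert BaguaEnvironment().service_port == 29500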
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/cluster_environment.py ADDED
@@ -0,0 +1,87 @@
+ # Copyright The PyTorch Lightning team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from abc import ABC, abstractmethod
+ from typing import Any, Type
+ 
+ from pytorch_lightning.utilities import rank_zero_deprecation
+ 
+ 
+ class ClusterEnvironment(ABC):
+     """Specification of a cluster environment."""
+ 
+     def __new__(cls, *args: Any, **kwargs: Any) -> "ClusterEnvironment":
+         # TODO: remove in 1.7
+         _check_for_deprecated_methods(cls)
+         return super().__new__(cls)
+ 
+     @property
+     @abstractmethod
+     def creates_processes_externally(self) -> bool:
+         """Whether the environment creates the subprocesses or not."""
+ 
+     @property
+     @abstractmethod
+     def main_address(self) -> str:
+         """The main address through which all processes connect and communicate."""
+ 
+     @property
+     @abstractmethod
+     def main_port(self) -> int:
+         """An open and configured port in the main node through which all processes communicate."""
+ 
+     @staticmethod
+     @abstractmethod
+     def detect() -> bool:
+         """Detects the environment settings corresponding to this cluster and returns ``True`` if they match."""
+ 
+     @abstractmethod
+     def world_size(self) -> int:
+         """The number of processes across all devices and nodes."""
+ 
+     @abstractmethod
+     def set_world_size(self, size: int) -> None:
+         pass
+ 
+     @abstractmethod
+     def global_rank(self) -> int:
+         """The rank (index) of the currently running process across all nodes and devices."""
+ 
+     @abstractmethod
+     def set_global_rank(self, rank: int) -> None:
+         pass
+ 
+     @abstractmethod
+     def local_rank(self) -> int:
+         """The rank (index) of the currently running process inside of the current node."""
+ 
+     @abstractmethod
+     def node_rank(self) -> int:
+         """The rank (index) of the node on which the current process runs."""
+ 
+     def teardown(self) -> None:
+         """Clean up any state set after execution finishes."""
+         pass
+ 
+ 
+ def _check_for_deprecated_methods(cls: Type[ClusterEnvironment]) -> None:
+     if hasattr(cls, "master_address") and callable(cls.master_address):
+         rank_zero_deprecation(
+             f"`{cls.__name__}.master_address` has been deprecated in v1.6 and will be removed in v1.7."
+             " Implement the property `main_address` instead (do not forget to add the `@property` decorator)."
+         )
+     if hasattr(cls, "master_port") and callable(cls.master_port):
+         rank_zero_deprecation(
+             f"`{cls.__name__}.master_port` has been deprecated in v1.6 and will be removed in v1.7."
+             " Implement the property `main_port` instead (do not forget to add the `@property` decorator)."
+         )
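The abstract interface above is what a custom cluster plugin has to satisfy; below is a hedged sketch of a user-defined environment driven by hypothetical `MY_*` variables (none of these names come from the commit).

    import os

    from pytorch_lightning.plugins.environments import ClusterEnvironment


    class MyClusterEnvironment(ClusterEnvironment):  # hypothetical example, not part of the diff
        @property
        def creates_processes_externally(self) -> bool:
            return True  # the scheduler launches one process per rank

        @property
        def main_address(self) -> str:
            return os.environ["MY_MASTER_ADDR"]

        @property
        def main_port(self) -> int:
            return int(os.environ["MY_MASTER_PORT"])

        @staticmethod
        def detect() -> bool:
            return "MY_MASTER_ADDR" in os.environ

        def world_size(self) -> int:
            return int(os.environ["MY_WORLD_SIZE"])

        def set_world_size(self, size: int) -> None:
            pass  # sizes come from the scheduler, so ignore

        def global_rank(self) -> int:
            return int(os.environ["MY_RANK"])

        def set_global_rank(self, rank: int) -> None:
            pass

        def local_rank(self) -> int:
            return int(os.environ.get("MY_LOCAL_RANK", 0))

        def node_rank(self) -> int:
            return int(os.environ.get("MY_NODE_RANK", 0))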
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/kubeflow_environment.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import os
17
+
18
+ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
19
+ from pytorch_lightning.utilities import rank_zero_deprecation
20
+
21
+ log = logging.getLogger(__name__)
22
+
23
+
24
+ class KubeflowEnvironment(ClusterEnvironment):
25
+ """Environment for distributed training using the `PyTorchJob`_ operator from `Kubeflow`_
26
+
27
+ .. _PyTorchJob: https://www.kubeflow.org/docs/components/training/pytorch/
28
+ .. _Kubeflow: https://www.kubeflow.org
29
+ """
30
+
31
+ def __init__(self) -> None:
32
+ super().__init__()
33
+ # TODO: remove in 1.7
34
+ if hasattr(self, "is_using_kubeflow") and callable(self.is_using_kubeflow):
35
+ rank_zero_deprecation(
36
+ f"`{self.__class__.__name__}.is_using_kubeflow` has been deprecated in v1.6 and will be removed in"
37
+ f" v1.7. Implement the static method `detect()` instead (do not forget to add the `@staticmethod`"
38
+ f" decorator)."
39
+ )
40
+
41
+ @property
42
+ def creates_processes_externally(self) -> bool:
43
+ return True
44
+
45
+ @property
46
+ def main_address(self) -> str:
47
+ return os.environ["MASTER_ADDR"]
48
+
49
+ @property
50
+ def main_port(self) -> int:
51
+ return int(os.environ["MASTER_PORT"])
52
+
53
+ @staticmethod
54
+ def detect() -> bool:
55
+ """Returns ``True`` if the current process was launched using Kubeflow PyTorchJob."""
56
+ required_env_vars = {"KUBERNETES_PORT", "MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"}
57
+ # torchelastic sets these. Make sure we're not in torchelastic
58
+ excluded_env_vars = {"GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE"}
59
+ env_vars = os.environ.keys()
60
+ return required_env_vars.issubset(env_vars) and excluded_env_vars.isdisjoint(env_vars)
61
+
62
+ def world_size(self) -> int:
63
+ return int(os.environ["WORLD_SIZE"])
64
+
65
+ def set_world_size(self, size: int) -> None:
66
+ log.debug("KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
67
+
68
+ def global_rank(self) -> int:
69
+ return int(os.environ["RANK"])
70
+
71
+ def set_global_rank(self, rank: int) -> None:
72
+ log.debug("KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
73
+
74
+ def local_rank(self) -> int:
75
+ return 0
76
+
77
+ def node_rank(self) -> int:
78
+ return self.global_rank()
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/lightning_environment.py ADDED
@@ -0,0 +1,101 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import socket
17
+
18
+ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
19
+ from pytorch_lightning.utilities.rank_zero import rank_zero_only
20
+
21
+
22
+ class LightningEnvironment(ClusterEnvironment):
23
+ """The default environment used by Lightning for a single node or free cluster (not managed).
24
+
25
+ There are two modes the Lightning environment can operate with:
26
+
27
+ 1. The user only launches the main process by :code:`python train.py ...` with no additional environment variables
28
+ set. Lightning will spawn new worker processes for distributed training in the current node.
29
+ 2. The user launches all processes manually or with utilities like :code:`torch.distributed.launch`.
30
+ The appropriate environment variables need to be set, and at minimum :code:`LOCAL_RANK`.
31
+
32
+ If the main address and port are not provided, the default environment will choose them
33
+ automatically. It is recommended to use this default environment for single-node distributed
34
+ training as it provides a convenient way to launch the training script.
35
+ """
36
+
37
+ def __init__(self) -> None:
38
+ super().__init__()
39
+ self._main_port: int = -1
40
+ self._global_rank: int = 0
41
+ self._world_size: int = 1
42
+
43
+ @property
44
+ def creates_processes_externally(self) -> bool:
45
+ """Returns whether the cluster creates the processes or not.
46
+
47
+ If at least :code:`LOCAL_RANK` is available as environment variable, Lightning assumes the user acts as the
48
+ process launcher/job scheduler and Lightning will not launch new processes.
49
+ """
50
+ return "LOCAL_RANK" in os.environ
51
+
52
+ @property
53
+ def main_address(self) -> str:
54
+ return os.environ.get("MASTER_ADDR", "127.0.0.1")
55
+
56
+ @property
57
+ def main_port(self) -> int:
58
+ if self._main_port == -1:
59
+ self._main_port = int(os.environ.get("MASTER_PORT", find_free_network_port()))
60
+ return self._main_port
61
+
62
+ @staticmethod
63
+ def detect() -> bool:
64
+ return True
65
+
66
+ def world_size(self) -> int:
67
+ return self._world_size
68
+
69
+ def set_world_size(self, size: int) -> None:
70
+ self._world_size = size
71
+
72
+ def global_rank(self) -> int:
73
+ return self._global_rank
74
+
75
+ def set_global_rank(self, rank: int) -> None:
76
+ self._global_rank = rank
77
+ rank_zero_only.rank = rank
78
+
79
+ def local_rank(self) -> int:
80
+ return int(os.environ.get("LOCAL_RANK", 0))
81
+
82
+ def node_rank(self) -> int:
83
+ group_rank = os.environ.get("GROUP_RANK", 0)
84
+ return int(os.environ.get("NODE_RANK", group_rank))
85
+
86
+ def teardown(self) -> None:
87
+ if "WORLD_SIZE" in os.environ:
88
+ del os.environ["WORLD_SIZE"]
89
+
90
+
91
+ def find_free_network_port() -> int:
92
+ """Finds a free port on localhost.
93
+
94
+ It is useful in single-node training when we don't want to connect to a real main node but have to set the
95
+ `MASTER_PORT` environment variable.
96
+ """
97
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
98
+ s.bind(("", 0))
99
+ port = s.getsockname()[1]
100
+ s.close()
101
+ return port
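A short sketch of the fallback behaviour above: with no launcher-managed variables set, the default environment resolves the address to localhost and, if `MASTER_PORT` is absent, picks a free ephemeral port via `find_free_network_port`.

    from pytorch_lightning.plugins.environments.lightning_environment import (
        LightningEnvironment,
        find_free_network_port,
    )

    env = LightningEnvironment()
    print(env.main_address)          # "127.0.0.1" unless MASTER_ADDR is exported
    print(env.main_port)             # MASTER_PORT if set, otherwise a freshly found free port
    print(find_free_network_port())  # the helper can also be used on its own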
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/lsf_environment.py ADDED
@@ -0,0 +1,190 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import socket
17
+ from typing import Dict, List
18
+
19
+ from pytorch_lightning import _logger as log
20
+ from pytorch_lightning.plugins.environments import ClusterEnvironment
21
+ from pytorch_lightning.utilities import rank_zero_deprecation
22
+ from pytorch_lightning.utilities.cloud_io import get_filesystem
23
+
24
+
25
+ class LSFEnvironment(ClusterEnvironment):
26
+ """An environment for running on clusters managed by the LSF resource manager.
27
+
28
+ It is expected that any execution using this ClusterEnvironment was executed
29
+ using the Job Step Manager i.e. ``jsrun``.
30
+
31
+ This plugin expects the following environment variables:
32
+
33
+ ``LSB_JOBID``
34
+ The LSF assigned job ID
35
+
36
+ ``LSB_DJOB_RANKFILE``
37
+ The OpenMPI compatible rank file for the LSF job
38
+
39
+ ``JSM_NAMESPACE_LOCAL_RANK``
40
+ The node local rank for the task. This environment variable is set by ``jsrun``
41
+
42
+ ``JSM_NAMESPACE_SIZE``
43
+ The world size for the task. This environment variable is set by ``jsrun``
44
+
45
+ ``JSM_NAMESPACE_RANK``
46
+ The global rank for the task. This environment variable is set by ``jsrun``
47
+ """
48
+
49
+ def __init__(self) -> None:
50
+ super().__init__()
51
+ # TODO: remove in 1.7
52
+ if hasattr(self, "is_using_lsf") and callable(self.is_using_lsf):
53
+ rank_zero_deprecation(
54
+ f"`{self.__class__.__name__}.is_using_lsf` has been deprecated in v1.6 and will be removed in v1.7."
55
+ " Implement the static method `detect()` instead (do not forget to add the `@staticmethod` decorator)."
56
+ )
57
+ self._main_address = self._get_main_address()
58
+ self._main_port = self._get_main_port()
59
+ self._node_rank = self._get_node_rank()
60
+ self._set_init_progress_group_env_vars()
61
+
62
+ def _set_init_progress_group_env_vars(self) -> None:
63
+ # set environment variables needed for initializing torch distributed process group
64
+ os.environ["MASTER_ADDR"] = str(self._main_address)
65
+ log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
66
+ os.environ["MASTER_PORT"] = str(self._main_port)
67
+ log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
68
+
69
+ @property
70
+ def creates_processes_externally(self) -> bool:
71
+ """LSF creates subprocesses, i.e., PyTorch Lightning does not need to spawn them."""
72
+ return True
73
+
74
+ @property
75
+ def main_address(self) -> str:
76
+ """The main address is read from an OpenMPI host rank file in the environment variable
77
+ ``LSB_DJOB_RANKFILE``."""
78
+ return self._main_address
79
+
80
+ @property
81
+ def main_port(self) -> int:
82
+ """The main port is calculated from the LSF job ID."""
83
+ return self._main_port
84
+
85
+ @staticmethod
86
+ def detect() -> bool:
87
+ """Returns ``True`` if the current process was launched using the ``jsrun`` command."""
88
+ required_env_vars = {"LSB_JOBID", "LSB_DJOB_RANKFILE", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE"}
89
+ return required_env_vars.issubset(os.environ.keys())
90
+
91
+ def world_size(self) -> int:
92
+ """The world size is read from the environment variable ``JSM_NAMESPACE_SIZE``."""
93
+ world_size = os.environ.get("JSM_NAMESPACE_SIZE")
94
+ if world_size is None:
95
+ raise ValueError(
96
+ "Cannot determine world size. Environment variable `JSM_NAMESPACE_SIZE` not found."
97
+ "Make sure you run your executable with `jsrun`."
98
+ )
99
+ return int(world_size)
100
+
101
+ def set_world_size(self, size: int) -> None:
102
+ log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
103
+
104
+ def global_rank(self) -> int:
105
+ """The world size is read from the environment variable ``JSM_NAMESPACE_RANK``."""
106
+ global_rank = os.environ.get("JSM_NAMESPACE_RANK")
107
+ if global_rank is None:
108
+ raise ValueError(
109
+ "Cannot determine global rank. Environment variable `JSM_NAMESPACE_RANK` not found."
110
+ "Make sure you run your executable with `jsrun`."
111
+ )
112
+ return int(global_rank)
113
+
114
+ def set_global_rank(self, rank: int) -> None:
115
+ log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
116
+
117
+ def local_rank(self) -> int:
118
+ """The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`."""
119
+ local_rank = os.environ.get("JSM_NAMESPACE_LOCAL_RANK")
120
+ if local_rank is None:
121
+ raise ValueError(
122
+ "Cannot determine local rank. Environment variable `JSM_NAMESPACE_LOCAL_RANK` not found."
123
+ "Make sure you run your executable with `jsrun`."
124
+ )
125
+ return int(local_rank)
126
+
127
+ def node_rank(self) -> int:
128
+ """The node rank is determined by the position of the current hostname in the OpenMPI host rank file stored
129
+ in ``LSB_DJOB_RANKFILE``."""
130
+ return self._node_rank
131
+
132
+ def _get_node_rank(self) -> int:
133
+ """A helper method for getting the node rank.
134
+
135
+ The node rank is determined by the position of the current node in the list of hosts used in the job. This is
136
+ calculated by reading all hosts from ``LSB_DJOB_RANKFILE`` and finding this node's hostname in the list.
137
+ """
138
+ hosts = self._read_hosts()
139
+ count: Dict[str, int] = {}
140
+ for host in hosts:
141
+ if host not in count:
142
+ count[host] = len(count)
143
+ return count[socket.gethostname()]
144
+
145
+ @staticmethod
146
+ def _read_hosts() -> List[str]:
147
+ """Read compute hosts that are a part of the compute job.
148
+
149
+ LSF uses the Job Step Manager (JSM) to manage job steps. Job steps are executed by the JSM from "launch" nodes.
150
+ Each job is assigned a launch node. This launch node will be the first node in the list contained in
151
+ ``LSB_DJOB_RANKFILE``.
152
+ """
153
+ var = "LSB_DJOB_RANKFILE"
154
+ rankfile = os.environ.get(var)
155
+ if rankfile is None:
156
+ raise ValueError("Did not find the environment variable `LSB_DJOB_RANKFILE`")
157
+ if not rankfile:
158
+ raise ValueError("The environment variable `LSB_DJOB_RANKFILE` is empty")
159
+
160
+ fs = get_filesystem(rankfile)
161
+ with fs.open(rankfile, "r") as f:
162
+ ret = [line.strip() for line in f]
163
+ # remove the launch node (i.e. the first node in LSB_DJOB_RANKFILE) from the list
164
+ return ret[1:]
165
+
166
+ def _get_main_address(self) -> str:
167
+ """A helper for getting the main address.
168
+
169
+ The main address is assigned to the first node in the list of nodes used for the job.
170
+ """
171
+ hosts = self._read_hosts()
172
+ return hosts[0]
173
+
174
+ @staticmethod
175
+ def _get_main_port() -> int:
176
+ """A helper function for accessing the main port.
177
+
178
+ Uses the LSF job ID so all ranks can compute the main port.
179
+ """
180
+ # check for user-specified main port
181
+ if "MASTER_PORT" in os.environ:
182
+ log.debug(f"Using externally specified main port: {os.environ['MASTER_PORT']}")
183
+ return int(os.environ["MASTER_PORT"])
184
+ if "LSB_JOBID" in os.environ:
185
+ port = int(os.environ["LSB_JOBID"])
186
+ # all ports should be in the 10k+ range
187
+ port = port % 1000 + 10000
188
+ log.debug(f"calculated LSF main port: {port}")
189
+ return port
190
+ raise ValueError("Could not find job id in environment variable LSB_JOBID")
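The port arithmetic above is deterministic, so every rank derives the same value from the job id without any communication; for an illustrative `LSB_JOBID` of 2718281:

    job_id = 2718281                    # hypothetical LSF job id
    main_port = job_id % 1000 + 10000   # same formula as LSFEnvironment._get_main_port
    assert main_port == 10281           # 281 + 10000, always in the 10k+ range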
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/slurm_environment.py ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import os
17
+ import re
18
+ from typing import Optional
19
+
20
+ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
21
+
22
+ log = logging.getLogger(__name__)
23
+
24
+
25
+ class SLURMEnvironment(ClusterEnvironment):
26
+ """Cluster environment for training on a cluster managed by SLURM.
27
+
28
+ Args:
29
+ auto_requeue: Whether automatic job resubmission is enabled or not. How and under which conditions a job gets
30
+ rescheduled gets determined by the owner of this plugin.
31
+ """
32
+
33
+ def __init__(self, auto_requeue: bool = True) -> None:
34
+ super().__init__()
35
+ self.auto_requeue = auto_requeue
36
+
37
+ @property
38
+ def creates_processes_externally(self) -> bool:
39
+ return True
40
+
41
+ @property
42
+ def main_address(self) -> str:
43
+ # figure out the root node addr
44
+ slurm_nodelist = os.environ.get("SLURM_NODELIST")
45
+ if slurm_nodelist:
46
+ root_node = slurm_nodelist.split(" ")[0].split(",")[0]
47
+ else:
48
+ root_node = "127.0.0.1"
49
+
50
+ root_node = self.resolve_root_node_address(root_node)
51
+ os.environ["MASTER_ADDR"] = root_node
52
+ log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
53
+ return root_node
54
+
55
+ @property
56
+ def main_port(self) -> int:
57
+ # -----------------------
58
+ # SLURM JOB = PORT number
59
+ # -----------------------
60
+ # this way every process knows what port to use
61
+ job_id = os.environ.get("SLURM_JOB_ID")
62
+ if job_id is not None:
63
+ # use the last 4 numbers in the job id as the id
64
+ default_port = job_id[-4:]
65
+ # all ports should be in the 10k+ range
66
+ default_port = int(default_port) + 15000
67
+ else:
68
+ default_port = 12910
69
+
70
+ # -----------------------
71
+ # PORT NUMBER = MASTER_PORT
72
+ # -----------------------
73
+ # in case the user passed it in
74
+ if "MASTER_PORT" in os.environ:
75
+ default_port = int(os.environ["MASTER_PORT"])
76
+ else:
77
+ os.environ["MASTER_PORT"] = str(default_port)
78
+
79
+ log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
80
+ return default_port
81
+
82
+ @staticmethod
83
+ def detect() -> bool:
84
+ """Returns ``True`` if the current process was launched on a SLURM cluster."""
85
+ return "SLURM_NTASKS" in os.environ
86
+
87
+ @staticmethod
88
+ def job_name() -> Optional[str]:
89
+ return os.environ.get("SLURM_JOB_NAME")
90
+
91
+ @staticmethod
92
+ def job_id() -> Optional[int]:
93
+ # in interactive mode, don't make logs use the same job id
94
+ in_slurm_interactive_mode = SLURMEnvironment.job_name() == "bash"
95
+ if in_slurm_interactive_mode:
96
+ return None
97
+
98
+ job_id = os.environ.get("SLURM_JOB_ID")
99
+ if job_id is None:
100
+ return None
101
+ try:
102
+ return int(job_id)
103
+ except ValueError:
104
+ return None
105
+
106
+ def world_size(self) -> int:
107
+ return int(os.environ["SLURM_NTASKS"])
108
+
109
+ def set_world_size(self, size: int) -> None:
110
+ log.debug("SLURMEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
111
+
112
+ def global_rank(self) -> int:
113
+ return int(os.environ["SLURM_PROCID"])
114
+
115
+ def set_global_rank(self, rank: int) -> None:
116
+ log.debug("SLURMEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
117
+
118
+ def local_rank(self) -> int:
119
+ return int(os.environ["SLURM_LOCALID"])
120
+
121
+ def node_rank(self) -> int:
122
+ return int(os.environ["SLURM_NODEID"])
123
+
124
+ def resolve_root_node_address(self, root_node: str) -> str:
125
+ if "[" in root_node:
126
+ name, numbers = root_node.split("[", maxsplit=1)
127
+ number = numbers.split(",", maxsplit=1)[0]
128
+ if "-" in number:
129
+ number = number.split("-")[0]
130
+
131
+ number = re.sub("[^0-9]", "", number)
132
+ root_node = name + number
133
+
134
+ return root_node
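To make the nodelist handling above concrete, a small sketch of how `resolve_root_node_address` collapses a bracketed SLURM hostname to the first node (hostnames are illustrative):

    from pytorch_lightning.plugins.environments import SLURMEnvironment

    env = SLURMEnvironment()
    # "node[05-09]" -> name "node" + first number "05" -> "node05"
    assert env.resolve_root_node_address("node[05-09]") == "node05"
    # plain hostnames pass through unchanged
    assert env.resolve_root_node_address("node05") == "node05"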
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/environments/torchelastic_environment.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import os
17
+
18
+ import torch.distributed
19
+
20
+ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
21
+ from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_9_1
22
+ from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
23
+
24
+ log = logging.getLogger(__name__)
25
+
26
+
27
+ class TorchElasticEnvironment(ClusterEnvironment):
28
+ """Environment for fault-tolerant and elastic training with `torchelastic <https://pytorch.org/elastic/>`_"""
29
+
30
+ def __init__(self) -> None:
31
+ super().__init__()
32
+ # TODO: remove in 1.7
33
+ if hasattr(self, "is_using_torchelastic") and callable(self.is_using_torchelastic):
34
+ rank_zero_deprecation(
35
+ f"`{self.__class__.__name__}.is_using_torchelastic` has been deprecated in v1.6 and will be removed in"
36
+ " v1.7. Implement the static method `detect()` instead (do not forget to add the `@staticmethod`"
37
+ " decorator)."
38
+ )
39
+
40
+ @property
41
+ def creates_processes_externally(self) -> bool:
42
+ return True
43
+
44
+ @property
45
+ def main_address(self) -> str:
46
+ if "MASTER_ADDR" not in os.environ:
47
+ rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
48
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
49
+ log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
50
+ return os.environ["MASTER_ADDR"]
51
+
52
+ @property
53
+ def main_port(self) -> int:
54
+ if "MASTER_PORT" not in os.environ:
55
+ rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910")
56
+ os.environ["MASTER_PORT"] = "12910"
57
+ log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
58
+
59
+ return int(os.environ["MASTER_PORT"])
60
+
61
+ @staticmethod
62
+ def detect() -> bool:
63
+ """Returns ``True`` if the current process was launched using the torchelastic command."""
64
+ if _TORCH_GREATER_EQUAL_1_9_1:
65
+ # if not available (for example on MacOS), `is_torchelastic_launched` is not defined
66
+ return torch.distributed.is_available() and torch.distributed.is_torchelastic_launched()
67
+ required_env_vars = {"RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE"}
68
+ return required_env_vars.issubset(os.environ.keys())
69
+
70
+ def world_size(self) -> int:
71
+ return int(os.environ["WORLD_SIZE"])
72
+
73
+ def set_world_size(self, size: int) -> None:
74
+ log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
75
+
76
+ def global_rank(self) -> int:
77
+ return int(os.environ["RANK"])
78
+
79
+ def set_global_rank(self, rank: int) -> None:
80
+ log.debug(
81
+ "TorchElasticEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
82
+ )
83
+
84
+ def local_rank(self) -> int:
85
+ return int(os.environ["LOCAL_RANK"])
86
+
87
+ def node_rank(self) -> int:
88
+ return int(os.environ.get("GROUP_RANK", 0))
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/__init__.py ADDED
@@ -0,0 +1,17 @@
+ # Copyright The PyTorch Lightning team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO  # noqa: F401
+ from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO  # noqa: F401
+ from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO  # noqa: F401
+ from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO  # noqa: F401
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/__pycache__/xla_plugin.cpython-38.pyc ADDED
Binary file (2.38 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/checkpoint_plugin.py ADDED
@@ -0,0 +1,62 @@
+ # Copyright The PyTorch Lightning team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, Optional
+ 
+ from pytorch_lightning.utilities.types import _PATH
+ 
+ 
+ class CheckpointIO(ABC):
+     """Interface to save/load checkpoints as they are saved through the ``Strategy``.
+ 
+     Typically most plugins either use the Torch based IO Plugin; ``TorchCheckpointIO`` but may
+     require particular handling depending on the plugin.
+ 
+     In addition, you can pass a custom ``CheckpointIO`` by extending this class and passing it
+     to the Trainer, i.e ``Trainer(plugins=[MyCustomCheckpointIO()])``.
+ 
+     .. note::
+ 
+         For some plugins, it is not possible to use a custom checkpoint plugin as checkpointing logic is not
+         modifiable.
+     """
+ 
+     @abstractmethod
+     def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
+         """Save model/training states as a checkpoint file through state-dump and file-write.
+ 
+         Args:
+             checkpoint: dict containing model and trainer state
+             path: write-target path
+             storage_options: Optional parameters when saving the model/training states.
+         """
+ 
+     @abstractmethod
+     def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]:
+         """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages.
+ 
+         Args:
+             path: Path to checkpoint
+             storage_options: Optional parameters when loading the model/training states.
+ 
+         Returns: The loaded checkpoint.
+         """
+ 
+     @abstractmethod
+     def remove_checkpoint(self, path: _PATH) -> None:
+         """Remove checkpoint file from the filesystem.
+ 
+         Args:
+             path: Path to checkpoint
+         """
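To illustrate the interface above, a minimal custom plugin that simply delegates to `torch.save`/`torch.load`; the class name is made up for this sketch, and the built-in `TorchCheckpointIO` further below already covers the same ground.

    import os
    from typing import Any, Dict, Optional

    import torch

    from pytorch_lightning.plugins.io import CheckpointIO


    class SimpleTorchCheckpointIO(CheckpointIO):  # hypothetical example
        def save_checkpoint(self, checkpoint: Dict[str, Any], path, storage_options: Optional[Any] = None) -> None:
            torch.save(checkpoint, path)

        def load_checkpoint(self, path, storage_options: Optional[Any] = None) -> Dict[str, Any]:
            return torch.load(path, map_location="cpu")

        def remove_checkpoint(self, path) -> None:
            os.remove(path)


    # wired into training as the docstring suggests:
    # trainer = Trainer(plugins=[SimpleTorchCheckpointIO()])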
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/hpu_plugin.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from typing import Any, Dict, Optional
17
+
18
+ import torch
19
+
20
+ from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
21
+ from pytorch_lightning.utilities.apply_func import move_data_to_device
22
+ from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
23
+ from pytorch_lightning.utilities.types import _PATH
24
+
25
+
26
+ class HPUCheckpointIO(TorchCheckpointIO):
27
+ """CheckpointIO to save checkpoints for HPU training strategies."""
28
+
29
+ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
30
+ """Save model/training states as a checkpoint file through state-dump and file-write.
31
+
32
+ Args:
33
+ checkpoint: dict containing model and trainer state
34
+ path: write-target path
35
+ storage_options: not used in ``XLACheckpointIO.save_checkpoint``
36
+
37
+ Raises:
38
+ TypeError:
39
+ If ``storage_options`` arg is passed in
40
+ """
41
+ if storage_options is not None:
42
+ raise TypeError(
43
+ "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
44
+ f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
45
+ " to define how you'd like to use `storage_options`."
46
+ )
47
+ fs = get_filesystem(path)
48
+ fs.makedirs(os.path.dirname(path), exist_ok=True)
49
+
50
+ checkpoint = move_data_to_device(checkpoint, torch.device("cpu"))
51
+ # write the checkpoint dictionary to the provided path
52
+ atomic_save(checkpoint, path)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/torch_plugin.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import os
16
+ from typing import Any, Callable, Dict, Optional
17
+
18
+ import pytorch_lightning as pl
19
+ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
20
+ from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
21
+ from pytorch_lightning.utilities.cloud_io import load as pl_load
22
+ from pytorch_lightning.utilities.rank_zero import rank_zero_warn
23
+ from pytorch_lightning.utilities.types import _PATH
24
+
25
+ log = logging.getLogger(__name__)
26
+
27
+
28
+ class TorchCheckpointIO(CheckpointIO):
29
+ """CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints
30
+ respectively, common for most use cases."""
31
+
32
+ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
33
+ """Save model/training states as a checkpoint file through state-dump and file-write.
34
+
35
+ Args:
36
+ checkpoint: dict containing model and trainer state
37
+ path: write-target path
38
+ storage_options: not used in ``TorchCheckpointIO.save_checkpoint``
39
+
40
+ Raises:
41
+ TypeError:
42
+ If ``storage_options`` arg is passed in
43
+ """
44
+ if storage_options is not None:
45
+ raise TypeError(
46
+ "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
47
+ f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
48
+ " to define how you'd like to use `storage_options`."
49
+ )
50
+ fs = get_filesystem(path)
51
+ fs.makedirs(os.path.dirname(path), exist_ok=True)
52
+ try:
53
+ # write the checkpoint dictionary on the file
54
+ atomic_save(checkpoint, path)
55
+ except AttributeError as err:
56
+ # todo (sean): is this try catch necessary still?
57
+ # https://github.com/PyTorchLightning/pytorch-lightning/pull/431
58
+ key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
59
+ checkpoint.pop(key, None)
60
+ rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}")
61
+ atomic_save(checkpoint, path)
62
+
63
+ def load_checkpoint(
64
+ self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage
65
+ ) -> Dict[str, Any]:
66
+ """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of
67
+ files.
68
+
69
+ Args:
70
+ path: Path to checkpoint
71
+ map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
72
+ locations.
73
+
74
+ Returns: The loaded checkpoint.
75
+
76
+ Raises:
77
+ FileNotFoundError: If ``path`` is not found by the ``fsspec`` filesystem
78
+ """
79
+
80
+ # Try to read the checkpoint at `path`. If not exist, do not restore checkpoint.
81
+ fs = get_filesystem(path)
82
+ if not fs.exists(path):
83
+ raise FileNotFoundError(f"Checkpoint at {path} not found. Aborting training.")
84
+
85
+ return pl_load(path, map_location=map_location)
86
+
87
+ def remove_checkpoint(self, path: _PATH) -> None:
88
+ """Remove checkpoint file from the filesystem.
89
+
90
+ Args:
91
+ path: Path to checkpoint
92
+ """
93
+ fs = get_filesystem(path)
94
+ if fs.exists(path):
95
+ fs.rm(path, recursive=True)
96
+ log.debug(f"Removed checkpoint: {path}")
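As a quick usage sketch of the plugin above, saving, reloading and deleting an arbitrary checkpoint dict directly through `TorchCheckpointIO` (the path is illustrative):

    import torch

    from pytorch_lightning.plugins.io import TorchCheckpointIO

    ckpt_io = TorchCheckpointIO()
    ckpt_io.save_checkpoint({"state_dict": {"weight": torch.zeros(2)}}, "checkpoints/demo.ckpt")
    restored = ckpt_io.load_checkpoint("checkpoints/demo.ckpt")  # {'state_dict': {'weight': tensor([0., 0.])}}
    ckpt_io.remove_checkpoint("checkpoints/demo.ckpt")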
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/io/xla_plugin.py ADDED
@@ -0,0 +1,57 @@
1
+ # Copyright The PyTorch Lightning team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ from typing import Any, Dict, Optional
16
+
17
+ from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
18
+ from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE
19
+ from pytorch_lightning.utilities.apply_func import apply_to_collection
20
+ from pytorch_lightning.utilities.cloud_io import get_filesystem
21
+ from pytorch_lightning.utilities.types import _PATH
22
+
23
+ if _TPU_AVAILABLE:
24
+ import torch_xla.core.xla_model as xm
25
+
26
+ if _OMEGACONF_AVAILABLE:
27
+ from omegaconf import DictConfig, ListConfig, OmegaConf
28
+
29
+
30
+ class XLACheckpointIO(TorchCheckpointIO):
31
+ """CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies."""
32
+
33
+ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
34
+ """Save model/training states as a checkpoint file through state-dump and file-write.
35
+
36
+ Args:
37
+ checkpoint: dict containing model and trainer state
38
+ path: write-target path
39
+ storage_options: not used in ``XLACheckpointIO.save_checkpoint``
40
+
41
+ Raises:
42
+ TypeError:
43
+ If ``storage_options`` arg is passed in
44
+ """
45
+ if storage_options is not None:
46
+ raise TypeError(
47
+ "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
48
+ f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
49
+ " to define how you'd like to use `storage_options`."
50
+ )
51
+ fs = get_filesystem(path)
52
+ fs.makedirs(os.path.dirname(path), exist_ok=True)
53
+ # Todo: TypeError: 'mappingproxy' object does not support item assignment
54
+ # Ref: https://github.com/pytorch/xla/issues/2773
55
+ if _OMEGACONF_AVAILABLE:
56
+ checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
57
+ xm.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, path)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__init__.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright The PyTorch Lightning team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin  # noqa: F401
+ from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin  # noqa: F401
+ from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin  # noqa: F401
+ from pytorch_lightning.plugins.precision.fully_sharded_native_amp import (  # noqa: F401
+     FullyShardedNativeMixedPrecisionPlugin,
+ )
+ from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin  # noqa: F401
+ from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin  # noqa: F401
+ from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin  # noqa: F401
+ from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin  # noqa: F401
+ from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin  # noqa: F401
+ from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin  # noqa: F401
+ from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin  # noqa: F401
+ from pytorch_lightning.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin  # noqa: F401
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/apex_amp.cpython-38.pyc ADDED
Binary file (3.74 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/deepspeed.cpython-38.pyc ADDED
Binary file (3.86 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/double.cpython-38.pyc ADDED
Binary file (3.99 kB)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/fully_sharded_native_amp.cpython-38.pyc ADDED
Binary file (999 Bytes)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/mixed.cpython-38.pyc ADDED
Binary file (719 Bytes)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/__pycache__/native_amp.cpython-38.pyc ADDED
Binary file (4.31 kB)