# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Vendored from Lightning-AI/pytorch-lightning commit:
9bcba1c1e82b45e10f948dc28fc12f4cf04ab736
Source:
https://github.com/Lightning-AI/pytorch-lightning/blob/9bcba1c1e82b45e10f948dc28fc12f4cf04ab736/src/lightning/pytorch/callbacks/weight_averaging.py
"""
import itertools
from copy import deepcopy
from typing import Any, Optional, Union
import torch
from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
from typing_extensions import override
import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.types import STEP_OUTPUT
class WeightAveraging(Callback):
    r"""Keep an averaged copy of the LightningModule's weights during training.

    In the ``fit`` stage this callback wraps a copy of the module in a
    :class:`~torch.optim.swa_utils.AveragedModel`. After the optimizer steps and/or training epochs selected by
    :meth:`should_update`, the average is updated from the current weights.

    The averaged weights are:

    - swapped into the module for each validation epoch, and swapped back afterwards;
    - copied into the module when training ends;
    - written to the checkpoint's ``state_dict`` when a checkpoint is saved. The un-averaged weights are then kept
      under ``checkpoint["current_model_state"]``.

    Subclasses change the update schedule by overriding :meth:`should_update`. The default updates after every
    optimizer step.

    Args:
        device: Device for the averaged model. A string is converted with :class:`torch.device`. An ``int`` or a
            :class:`torch.device` is passed through unchanged. ``None`` uses the module's device at ``setup`` time.
        use_buffers: Forwarded to ``AveragedModel``.
        **kwargs: Extra keyword arguments forwarded to ``AveragedModel`` (for example ``avg_fn``).
    """

    def __init__(
        self,
        device: Optional[Union[torch.device, str, int]] = None,
        use_buffers: bool = True,
        **kwargs: Any,
    ) -> None:
        if isinstance(device, str):
            self._device: Optional[Union[torch.device, int]] = torch.device(device)
        else:
            self._device = device
        self._use_buffers = use_buffers
        self._kwargs = kwargs

        # Built in setup() for the "fit" stage; stays None when the callback is used outside of fitting.
        self._average_model: Optional[AveragedModel] = None

        # trainer.global_step at the most recent step-based update. Starting at 0 means the first update can only
        # happen once global_step has advanced past 0.
        self._latest_update_step = 0
        # trainer.current_epoch at the most recent epoch-based update; -1 means no epoch update has happened yet.
        self._latest_update_epoch = -1

    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
        """Return whether the average model should be updated now.

        Args:
            step_idx: Zero-based index of the optimizer step that just finished (``trainer.global_step - 1``).
                Passed by :meth:`on_train_batch_end`; otherwise ``None``.
            epoch_idx: ``trainer.current_epoch`` of the epoch that just finished. Passed by
                :meth:`on_train_epoch_end`; otherwise ``None``.

        Returns:
            ``True`` to update. The default updates after every optimizer step and never at the end of an epoch.
        """
        return step_idx is not None

    @override
    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        """Create the averaged model when the ``fit`` stage starts; other stages are ignored."""
        if stage != "fit":
            return

        # Compare against None rather than relying on truthiness. ``device=0`` (an integer device index) is falsy
        # and must not silently fall back to the module's device.
        device = self._device if self._device is not None else pl_module.device

        # Run configure_model() so the layers exist before AveragedModel copies the module. Sharding is not
        # applied, hence the warning.
        if is_overridden("configure_model", pl_module):
            rank_zero_warn(
                "You're using the WeightAveraging callback with a model that overrides the configure_model "
                "callback. WeightAveraging doesn't support sharding model layers, so you may run out of memory."
            )
            pl_module.configure_model()

        self._average_model = AveragedModel(
            model=pl_module,
            device=device,
            use_buffers=self._use_buffers,
            **self._kwargs,
        )

    @override
    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Update the average after an optimizer step when :meth:`should_update` agrees."""
        # trainer.global_step is converted to a zero-based index so it lines up with the epoch index passed to
        # should_update().
        step_idx = trainer.global_step - 1
        # Skip batches after which global_step has not advanced since the last update, so one optimizer step
        # never triggers two updates.
        if trainer.global_step > self._latest_update_step and self.should_update(step_idx=step_idx):
            assert self._average_model is not None
            self._average_model.update_parameters(pl_module)
            self._latest_update_step = trainer.global_step

    @override
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Update the average at the end of an epoch when :meth:`should_update` agrees (at most once per epoch)."""
        if trainer.current_epoch > self._latest_update_epoch and self.should_update(epoch_idx=trainer.current_epoch):
            assert self._average_model is not None
            self._average_model.update_parameters(pl_module)
            self._latest_update_epoch = trainer.current_epoch

    @override
    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Leave the module holding the averaged weights once training finishes."""
        assert self._average_model is not None
        self._copy_average_to_current(pl_module)

    @override
    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Validate with the averaged weights by swapping them into the module."""
        if self._average_model is not None:
            self._swap_models(pl_module)

    @override
    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Swap the current (un-averaged) weights back into the module after validation."""
        if self._average_model is not None:
            self._swap_models(pl_module)

    @override
    def state_dict(self) -> dict[str, Any]:
        """Return the update bookkeeping so that neither update guard is reset when training resumes."""
        return {
            "latest_update_step": self._latest_update_step,
            "latest_update_epoch": self._latest_update_epoch,
        }

    @override
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore the update bookkeeping saved by :meth:`state_dict`."""
        self._latest_update_step = state_dict["latest_update_step"]
        # Checkpoints written before the epoch counter was persisted lack this key; keep the current value then.
        self._latest_update_epoch = state_dict.get("latest_update_epoch", self._latest_update_epoch)

    @override
    def on_save_checkpoint(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        checkpoint: dict[str, Any],
    ) -> None:
        """Store the averaged weights as the checkpoint's ``state_dict``.

        The current weights move to ``checkpoint["current_model_state"]``. The ``AveragedModel`` entries that do not
        belong to the wrapped module (those without the ``"module."`` prefix) go to ``checkpoint["averaging_state"]``.
        """
        if self._average_model is None:
            rank_zero_info(
                "You're using the WeightAveraging callback, but saving a checkpoint outside the 'fit' stage. The state "
                "of the WeightAveraging callback won't be saved in the checkpoint. If training has finished, the "
                "average model parameters will be saved to the state_dict in the checkpoint."
            )
            return

        prefix = "module."
        module_state: dict[str, Any] = {}
        averaging_state: dict[str, Any] = {}
        for name, value in self._average_model.state_dict().items():
            if name.startswith(prefix):
                module_state[name[len(prefix) :]] = value
            else:
                averaging_state[name] = value

        checkpoint["current_model_state"] = checkpoint["state_dict"]
        checkpoint["state_dict"] = module_state
        checkpoint["averaging_state"] = averaging_state

    @override
    def on_load_checkpoint(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        checkpoint: dict[str, Any],
    ) -> None:
        """Restore the averaged and current weights from a checkpoint.

        A checkpoint written by :meth:`on_save_checkpoint` restores both models exactly. Any other checkpoint only
        seeds the averaged module (non-strictly) from its ``state_dict``.
        """
        if self._average_model is None:
            rank_zero_warn(
                "You're using the WeightAveraging callback, but loading a checkpoint outside the 'fit' stage. The "
                "WeightAveraging state cannot be restored. If you're using the checkpoint for prediction or testing, "
                "you can ignore this warning. To disable the warning, remove the WeightAveraging callback."
            )
        elif "current_model_state" in checkpoint and "averaging_state" in checkpoint:
            rank_zero_info("Found current_model_state in the checkpoint. This will be used to initialize the model.")
            # Invert the split done in on_save_checkpoint to rebuild the full AveragedModel state dict.
            average_model_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()}
            average_model_state |= checkpoint["averaging_state"]
            self._average_model.load_state_dict(average_model_state)
            pl_module.load_state_dict(checkpoint["current_model_state"])
        else:
            rank_zero_warn(
                "The checkpoint was not created with WeightAveraging. Both the current and the average model will be "
                "initialized with state_dict."
            )
            # deepcopy so the averaged module does not share tensors with the checkpoint dict.
            self._average_model.module.load_state_dict(deepcopy(checkpoint["state_dict"]), strict=False)

    def _paired_tensors(self, pl_module: "pl.LightningModule") -> "zip[tuple[torch.Tensor, torch.Tensor]]":
        """Pair each averaged parameter/buffer with the corresponding tensor of ``pl_module``, in module order."""
        assert self._average_model is not None
        average_tensors = itertools.chain(
            self._average_model.module.parameters(),
            self._average_model.module.buffers(),
        )
        current_tensors = itertools.chain(pl_module.parameters(), pl_module.buffers())
        return zip(average_tensors, current_tensors)

    def _swap_models(self, pl_module: "pl.LightningModule") -> None:
        """Exchange the values of the averaged and current parameters/buffers in place."""
        for average_tensor, current_tensor in self._paired_tensors(pl_module):
            tmp = average_tensor.data.clone()
            average_tensor.data.copy_(current_tensor.data)
            current_tensor.data.copy_(tmp)

    def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
        """Overwrite the module's parameters/buffers with the averaged values in place."""
        for average_tensor, current_tensor in self._paired_tensors(pl_module):
            current_tensor.data.copy_(average_tensor.data)
class EMAWeightAveraging(WeightAveraging):
    """Weight averaging with an exponential moving average (EMA) of the model weights.

    Each update combines the averaged and current weights with the function returned by
    :func:`torch.optim.swa_utils.get_ema_avg_fn`, i.e. ``decay * average + (1 - decay) * current``. Updates can be
    scheduled per optimizer step, per training epoch, or both (see :meth:`should_update`).

    Args:
        device: Device for the averaged model, forwarded to ``WeightAveraging``.
        use_buffers: Forwarded to ``AveragedModel``.
        decay: EMA decay factor; must lie in ``[0, 1]``.
        update_every_n_steps: A step-based update happens only when the zero-based step index is a multiple of
            this value. ``0`` or a negative value disables step-based updates.
        update_starting_at_step: First zero-based step index eligible for a step-based update. ``None`` places no
            lower bound.
        update_starting_at_epoch: First epoch index after which an end-of-epoch update is made. ``None`` disables
            epoch-based updates.
        **kwargs: Extra keyword arguments forwarded to ``AveragedModel``. ``avg_fn`` is supplied by this class and
            must not be passed (it would be a duplicate keyword).

    Raises:
        ValueError: If ``decay`` is outside ``[0, 1]`` or is NaN.
    """

    def __init__(
        self,
        device: Optional[Union[torch.device, str, int]] = None,
        use_buffers: bool = True,
        decay: float = 0.999,
        update_every_n_steps: int = 1,
        update_starting_at_step: Optional[int] = None,
        update_starting_at_epoch: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        # A decay outside [0, 1] would extrapolate instead of average. The chained comparison is also False for NaN.
        if not 0.0 <= decay <= 1.0:
            raise ValueError(f"decay must be in the range [0, 1], got {decay}.")
        super().__init__(
            device=device,
            use_buffers=use_buffers,
            **kwargs,
            avg_fn=get_ema_avg_fn(decay=decay),
        )
        self.update_every_n_steps = update_every_n_steps
        self.update_starting_at_step = update_starting_at_step
        self.update_starting_at_epoch = update_starting_at_epoch

    @override
    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
        """Return whether the EMA should be updated for the given step or epoch index.

        A step qualifies when ``step_idx`` is at or after ``update_starting_at_step`` (if set) and is a multiple of a
        positive ``update_every_n_steps``. An epoch qualifies when ``update_starting_at_epoch`` is set and
        ``epoch_idx`` is at or after it.
        """
        if step_idx is not None:
            after_start_step = self.update_starting_at_step is None or step_idx >= self.update_starting_at_step
            on_step_interval = self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0
            if after_start_step and on_step_interval:
                return True
        if epoch_idx is not None:
            if self.update_starting_at_epoch is not None and epoch_idx >= self.update_starting_at_epoch:
                return True
        return False
|