#!/usr/bin/env python3

"""
Used for EMA tracking of a given pytorch module. The user is responsible for
calling step() and setting the appropriate decay.
"""

import copy
import logging
from dataclasses import dataclass, field

import torch

from fairseq.dataclass import FairseqDataclass

try:
    from amp_C import multi_tensor_l2norm

    multi_tensor_l2norm_available = True
except ImportError:
    multi_tensor_l2norm_available = False

logger = logging.getLogger(__name__)


@dataclass
class EMAModuleConfig(FairseqDataclass):
    ema_decay: float = field(
        default=0.9999, metadata={"help": "decay for exponential moving average model"}
    )
    ema_fp32: bool = field(
        default=False,
        metadata={"help": "If true, store EMA model in fp32 even if model is in fp16"},
    )
    add_missing_params: bool = True
    log_norms: bool = False
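
# Note: with decay very close to 1 (e.g. the default 0.9999), each EMA update
# moves a weight by only (1 - decay) * (param - ema). In fp16 that increment
# can underflow, so the EMA silently stops tracking the model; ema_fp32=True
# guards against this (see build_fp32_params below).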


class EMAModule:
    """Exponential Moving Average of Fairseq Models"""

    def __init__(
        self,
        model,
        config: EMAModuleConfig,
        copy_model=True,
        device=None,
        skip_keys=None,
    ):
        """
        @param model model to initialize the EMA with
        @param config EMAModuleConfig object with configuration like
            ema_decay, ema_fp32
        @param device If provided, copy EMA to this device (e.g. gpu).
            Otherwise EMA is in the same device as the model.
        """
        self.config = config

        if copy_model:
            self.model = copy.deepcopy(model)
            self.model.requires_grad_(False)
        else:
            self.model = model

        self.decay = config.ema_decay
        self.skip_keys = skip_keys or set()
        self.add_missing_params = config.add_missing_params
        self.fp32_params = {}

        if device is not None:
            logger.info(f"Copying EMA model to device {device}")
            self.model = self.model.to(device=device)

        if self.config.ema_fp32:
            self.build_fp32_params()

        self.log_norms = config.log_norms and multi_tensor_l2norm_available
        self.logs = {}

    def build_fp32_params(self, state_dict=None):
        """
        Store a copy of the EMA params in fp32.
        If a state dict is passed, the EMA params are copied from
        the provided state dict. Otherwise, they are copied from the
        current EMA model parameters.
        """
        if not self.config.ema_fp32:
            raise RuntimeError(
                "build_fp32_params should not be called if ema_fp32=False. "
                "Use ema_fp32=True if this is really intended."
            )

        if state_dict is None:
            state_dict = self.model.state_dict()

        def _to_float(t):
            return t.float() if torch.is_floating_point(t) else t

        for param_key in state_dict:
            if param_key in self.fp32_params:
                # "__sq_mom" appears to be a reserved key for externally
                # maintained second-moment buffers; it is replaced wholesale
                # rather than copied element-wise.
                if param_key == "__sq_mom":
                    self.fp32_params[param_key] = state_dict[param_key]
                else:
                    self.fp32_params[param_key].copy_(state_dict[param_key])
            else:
                self.fp32_params[param_key] = _to_float(state_dict[param_key])
                if "__sq_mom" in self.fp32_params:
                    self.fp32_params["__sq_mom"][param_key] = torch.zeros_like(
                        self.fp32_params[param_key]
                    )

    def restore(self, state_dict, build_fp32_params=False):
        """Load data from a model spec into EMA model"""
        self.model.load_state_dict(state_dict, strict=False)
        if build_fp32_params:
            self.build_fp32_params(state_dict)

    def set_decay(self, decay, weight_decay=None):
        self.decay = decay
        if weight_decay is not None:
            self.weight_decay = weight_decay

    def get_decay(self):
        return self.decay
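
    # A common calling pattern (a sketch, not part of this module): anneal the
    # decay from a small start value up to config.ema_decay over some number
    # of warmup updates, calling set_decay() before each step(), e.g.
    #
    #     def annealed_decay(start, end, step, warmup):
    #         return end - (end - start) * max(0.0, 1.0 - step / warmup)
    #
    # so early updates track the model closely and later ones average slowly.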

    def _step_internal(self, new_model):
        """One update of the EMA model based on new model weights"""
        decay = self.decay

        ema_state_dict = {}
        ema_params = (
            self.fp32_params if self.config.ema_fp32 else self.model.state_dict()
        )

        new_p = []
        ema_p = []

        for key, param in new_model.named_parameters():
            if isinstance(param, dict):
                continue

            if not self.add_missing_params and key not in ema_params:
                continue

            try:
                ema_param = ema_params[key]
            except KeyError:
                # Parameter is new to the EMA: seed it from the model weights.
                ema_param = (
                    param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                )
                ema_params[key] = ema_param

            if param.shape != ema_param.shape:
                raise ValueError(
                    "incompatible tensor shapes between model param and ema param "
                    + "{} vs. {}".format(param.shape, ema_param.shape)
                )

            if "version" in key:
                # Do not decay a model.version pytorch param
                continue

            lr = 1 - decay

            if key in self.skip_keys or not param.requires_grad:
                # Skipped and frozen params are copied verbatim, not averaged.
                ema_params[key].copy_(param.to(dtype=ema_param.dtype).data)
                ema_param = ema_params[key]
            else:
                if self.log_norms:
                    new_p.append(param)
                    ema_p.append(ema_param)

                # Standard EMA update: ema = decay * ema + (1 - decay) * param
                ema_param.mul_(1 - lr)
                ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=lr)

            ema_state_dict[key] = ema_param

        for key, param in new_model.named_buffers():
            ema_state_dict[key] = param

        if self.log_norms:
            if "model_norm" in self.logs:
                self.prev_model_norm = self.logs["model_norm"]

            chunk_size = 2048 * 32
            has_inf = torch.zeros(
                (1, 1), dtype=torch.int, device=next(new_model.parameters()).device
            )

            new_norm = multi_tensor_l2norm(chunk_size, has_inf, [new_p], False)
            old_norm = multi_tensor_l2norm(chunk_size, has_inf, [ema_p], False)

            self.logs["model_norm"] = new_norm[0]
            self.logs["ema_norm"] = old_norm[0]

        self.restore(ema_state_dict, build_fp32_params=False)

    @torch.no_grad()
    def step(self, new_model):
        self._step_internal(new_model)
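
    # Rough intuition for choosing the decay: the EMA's time constant is about
    # 1 / (1 - decay) updates, so decay=0.999 averages over roughly the last
    # 1,000 steps and decay=0.9999 over roughly the last 10,000.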

    def reverse(self, model):
        """
        Load the model parameters from EMA model.
        Useful for inference or fine-tuning from the EMA model.
        """
        d = self.model.state_dict()
        if "_ema" in d:
            del d["_ema"]
        model.load_state_dict(d, strict=False)
        return model
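

# Minimal usage sketch (illustration only, not part of the original module).
# Assumes a plain torch.nn.Module; the model, data, and optimizer below are
# hypothetical stand-ins for a real training loop.
if __name__ == "__main__":
    import torch.nn as nn

    net = nn.Linear(8, 2)
    ema = EMAModule(net, EMAModuleConfig(ema_decay=0.999, ema_fp32=False))
    opt = torch.optim.SGD(net.parameters(), lr=0.1)

    for _ in range(10):
        loss = net(torch.randn(4, 8)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        ema.step(net)  # EMA update after each optimizer step

    # Copy the averaged weights into a fresh module for evaluation.
    eval_net = ema.reverse(nn.Linear(8, 2))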