File size: 11,540 Bytes
66003a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import itertools
from typing import Any, Dict, List, Mapping, Iterable, Set, Tuple, Union
import hydra
import torch
import torch.nn as nn
from torch import Tensor
# -----------------------------------------------------------------------------
# Optimizer wrapper
# -----------------------------------------------------------------------------
class OptimizerWrapper:
    """Wraps a ``torch.optim.Optimizer`` and its per-param-group schedulers.

    ``schedulers`` is either ``None`` or a sequence holding one mapping per
    optimizer param group, aligned by index with ``optimizer.param_groups``.
    Each mapping goes from an optimizer option name (a key of
    ``optimizer.defaults``) to a callable that receives the progress value
    ``where`` and returns the option's value for that point.
    """

    def __init__(self, optimizer: torch.optim.Optimizer, schedulers=None) -> None:
        self.optimizer = optimizer
        self.schedulers = schedulers
        self._validate_optimizer_schedulers()
        # Apply the schedules immediately so the optimizer starts with the
        # option values scheduled for where == 0.0.
        self.step_schedulers(0.0)

    # ---------------------------------------------------------------------
    # Public API mirroring torch.optim.Optimizer
    # ---------------------------------------------------------------------
    def step(self, where: float = 1.0, closure=None):
        """Update scheduled options for progress ``where``, then step the optimizer.

        Returns whatever ``optimizer.step(closure)`` returns.
        """
        self.step_schedulers(where)
        return self.optimizer.step(closure)

    def zero_grad(self, *args, **kwargs):
        """Forward all arguments to ``optimizer.zero_grad``."""
        return self.optimizer.zero_grad(*args, **kwargs)

    def _validate_optimizer_schedulers(self) -> None:
        """Check that every scheduled option is a known optimizer option.

        Raises:
            AssertionError: if a scheduler mapping names an option that is not
                a key of ``optimizer.defaults``. Raised explicitly rather than
                via ``assert`` so the check still runs under ``python -O``.
        """
        if self.schedulers is None:
            return
        valid_options = self.optimizer.defaults
        for sched_map in self.schedulers:
            for option in sched_map:
                if option not in valid_options:
                    raise AssertionError(
                        f"Optimizer option {option} not found in {self.optimizer}. "
                        f"Valid options are {self.optimizer.defaults.keys()}"
                    )

    def step_schedulers(self, where: float) -> None:
        """Set every scheduled option of every param group to its value at ``where``.

        Param group ``i`` is driven by ``self.schedulers[i]``. If there are
        fewer scheduler mappings than param groups, indexing raises
        ``IndexError``.
        """
        if self.schedulers is None:
            return
        for i, param_group in enumerate(self.optimizer.param_groups):
            for option, scheduler in self.schedulers[i].items():
                param_group[option] = scheduler(where)
# -----------------------------------------------------------------------------
# Validation helpers
# -----------------------------------------------------------------------------
def validate_param_group_params(param_groups: List[Dict], model: nn.Module):
    """Ensure param groups are non-overlapping and cover exactly the model's params.

    Args:
        param_groups: optimizer param groups; each must have a ``"params"``
            sequence.
        model: module whose ``named_parameters()`` define the full set of
            parameters the groups must cover.

    Raises:
        AssertionError: if a group lists the same parameter twice, if two
            groups share a parameter, or if the union of all groups differs
            from the model's parameters. Raised explicitly (not via ``assert``)
            so the checks survive ``python -O``; an empty ``param_groups`` is
            handled instead of crashing inside ``set.union``.
    """
    covered = set()
    for pg in param_groups:
        group = set(pg["params"])
        if len(group) != len(pg["params"]):
            raise AssertionError("Parameter group contains duplicate parameters")
        # Checking against the running union is a single pass, instead of
        # comparing every ordered pair of groups.
        if not covered.isdisjoint(group):
            raise AssertionError("Parameter groups should be disjoint")
        covered |= group
    model_parameters = {p for _, p in model.named_parameters()}
    if covered != model_parameters:
        raise AssertionError(
            "Parameter groups must cover ALL model parameters "
            f"(found {len(covered)} / {len(model_parameters)})"
        )
# -----------------------------------------------------------------------------
# Glob helpers for pattern matching
# -----------------------------------------------------------------------------
from wcmatch import fnmatch
# wcmatch fnmatch flags applied to every parameter-name glob in this module:
#   CASE     -> always match case-sensitively
#   DOTMATCH -> let wildcards match '.' characters (see the wcmatch docs for
#               the exact leading-dot rules)
#   EXTMATCH -> enable extended patterns such as *(foo|bar)
#   SPLIT    -> treat "pat1|pat2" as alternative patterns
GLOB_FLAGS = fnmatch.SPLIT | fnmatch.EXTMATCH | fnmatch.DOTMATCH | fnmatch.CASE
def get_full_parameter_name(module_name: str, param_name: str) -> str:
    """Return the dotted name of ``param_name`` inside the module ``module_name``.

    The root module is named ``""``; its parameters keep their bare name.
    """
    if module_name == "":
        return param_name
    return f"{module_name}.{param_name}"
def get_module_cls_to_param_names(model: nn.Module) -> Dict[type, Set[str]]:
    """Map each module class in ``model`` to the parameters its instances own directly.

    Only parameters registered on a module itself (``recurse=False``) are
    attributed to that module's class; a child's parameters go to the child's
    class. Every class seen in ``model.named_modules()`` gets an entry, which
    may be an empty set. Names are full dotted parameter names.
    """
    cls_to_names: Dict[type, Set[str]] = {}
    for prefix, submodule in model.named_modules():
        owned = cls_to_names.setdefault(type(submodule), set())
        owned.update(
            get_full_parameter_name(prefix, local_name)
            for local_name, _ in submodule.named_parameters(recurse=False)
        )
    return cls_to_names
def unix_param_pattern_to_parameter_names(filter_param_names: Union[List[str], None],
                                          parameter_names: Set[str]) -> Set[str]:
    """Return the parameter names matched by any of the given glob patterns.

    Args:
        filter_param_names: glob patterns, matched with wcmatch ``fnmatch``
            using ``GLOB_FLAGS``; or ``None``.
        parameter_names: candidate full parameter names.

    Returns:
        The union of the names matched by each pattern. An empty set when
        ``filter_param_names`` is ``None`` or empty (the latter previously
        crashed in ``set.union()`` with no arguments).

    Raises:
        AssertionError: if any single pattern matches no parameter name.
    """
    if filter_param_names is None:
        return set()
    allowed: Set[str] = set()
    for pat in filter_param_names:
        matches = set(fnmatch.filter(parameter_names, pat, flags=GLOB_FLAGS))
        if not matches:
            raise AssertionError(f"Pattern {pat} matched no parameters")
        logging.info("Matches for param pattern [%s]: %s", pat, matches)
        allowed |= matches
    return allowed
def unix_module_cls_pattern_to_parameter_names(filter_module_cls_names: Union[List[str], None],
                                               module_cls_to_param_names: Dict[type, Set[str]]) -> Set[str]:
    """Return the parameter names owned by any of the given module classes.

    Args:
        filter_module_cls_names: dotted class paths, each resolved with
            ``hydra.utils.get_class``; or ``None``.
        module_cls_to_param_names: mapping from module class to the parameter
            names its instances own directly.

    Returns:
        The union of the parameter names of every listed class. An empty set
        when ``filter_module_cls_names`` is ``None`` or empty (the latter
        previously crashed in ``set.union()`` with no arguments).

    Raises:
        AssertionError: if a class does not occur in the model, or owns no
            parameters.
    """
    if filter_module_cls_names is None:
        return set()
    allowed: Set[str] = set()
    for cls_name in filter_module_cls_names:
        module_cls = hydra.utils.get_class(cls_name)
        if module_cls not in module_cls_to_param_names:
            raise AssertionError(f"Module class {cls_name} not found in model")
        params = module_cls_to_param_names[module_cls]
        if not params:
            raise AssertionError(f"Module class {cls_name} has no parameters")
        logging.info("Matches for module [%s]: %s", cls_name, params)
        # In-place union into a fresh set; the caller's mapping is not mutated.
        allowed |= params
    return allowed
def _unix_pattern_to_parameter_names(scheduler_cfg,
parameter_names: Set[str],
module_cls_to_param_names: Dict[type, Set[str]]):
if "param_names" not in scheduler_cfg and "module_cls_names" not in scheduler_cfg:
return None
return unix_param_pattern_to_parameter_names(
scheduler_cfg.get("param_names"), parameter_names
).union(
unix_module_cls_pattern_to_parameter_names(
scheduler_cfg.get("module_cls_names"), module_cls_to_param_names
)
)
# -----------------------------------------------------------------------------
# Scheduler helpers
# -----------------------------------------------------------------------------
def set_default_parameters(scheduler_cfgs: List[dict], all_parameter_names: Set[str]):
    """Fill in the default parameter set for one option's scheduler configs, in place.

    A config whose ``"parameter_names"`` is ``None`` is this option's default.
    It receives every name in ``all_parameter_names`` that no other config
    claims. At most one default is allowed. If there is none, a config
    ``{"parameter_names": <unclaimed names>}`` without a scheduler is appended,
    so every parameter is still covered by some config.

    The default set is always a new set and never aliases
    ``all_parameter_names``, which callers may share across options.

    Raises:
        AssertionError: if more than one config is a default. This is checked
            before any config is modified, and raised explicitly so it
            survives ``python -O``.
    """
    defaults = [cfg for cfg in scheduler_cfgs if cfg["parameter_names"] is None]
    if len(defaults) > 1:
        raise AssertionError("At most one default scheduler per option")
    # set().union(...) also handles the case where nothing is claimed.
    claimed = set().union(
        *(cfg["parameter_names"] for cfg in scheduler_cfgs
          if cfg["parameter_names"] is not None)
    )
    default_params = set(all_parameter_names) - claimed
    if defaults:
        defaults[0]["parameter_names"] = default_params
    else:
        scheduler_cfgs.append({"parameter_names": default_params})
def name_constraints_to_parameters(param_constraints: List[Set[str]],
                                   named_parameters: Dict[str, Tensor]) -> List[Tensor]:
    """Return the parameters whose names satisfy every constraint.

    Args:
        param_constraints: collections of allowed parameter names. A parameter
            is selected only if its name is in all of them. An empty list
            imposes no constraint, so every parameter is selected; previously
            this raised ``TypeError`` from ``set.intersection()``.
        named_parameters: mapping from full parameter name to parameter.

    Returns:
        The selected parameters, in ``named_parameters`` iteration order.
    """
    if not param_constraints:
        return list(named_parameters.values())
    first, *rest = param_constraints
    # Build a real set from the first constraint. ``set.intersection(x, ...)``
    # rejects a first argument that is not a ``set`` (e.g. a frozenset).
    matching_names = set(first).intersection(*rest)
    return [param for name, param in named_parameters.items() if name in matching_names]
def map_scheduler_cfgs_to_param_groups(all_scheduler_cfgs: Iterable[List[dict]],
                                       named_parameters: Dict[str, Tensor]):
    """Turn per-option scheduler configs into torch param groups plus schedulers.

    Each combination that picks one config per option (the cartesian product)
    defines a candidate group: the parameters whose names satisfy every picked
    config's ``"parameter_names"``. Combinations selecting no parameter are
    dropped. For each remaining group, the scheduler mapping holds
    ``option -> scheduler`` for the picked configs that carry an ``"option"``
    key.

    Returns:
        ``(schedulers, param_groups)``: two lists aligned by index.
    """
    schedulers: List[Dict[str, Any]] = []
    param_groups: List[Dict[str, List[Tensor]]] = []
    for combo in itertools.product(*all_scheduler_cfgs):
        selected = name_constraints_to_parameters(
            [cfg["parameter_names"] for cfg in combo], named_parameters
        )
        if len(selected) == 0:
            continue
        option_to_scheduler: Dict[str, Any] = {}
        for cfg in combo:
            if "option" in cfg:
                option_to_scheduler[cfg["option"]] = cfg["scheduler"]
        schedulers.append(option_to_scheduler)
        param_groups.append({"params": selected})
    return schedulers, param_groups
# -----------------------------------------------------------------------------
# Public factory functions
# -----------------------------------------------------------------------------
def construct_optimizer(model: nn.Module,
                        optimizer_conf: Any,
                        options_conf: Union[Mapping[str, List], None] = None,
                        param_group_modifiers_conf: Union[List, None] = None,
                        validate_param_groups: bool = True) -> OptimizerWrapper:
    """Build an ``OptimizerWrapper`` for ``model`` from hydra configs.

    There is *no* allowlist handling: every model parameter is optimized.

    Args:
        model: module whose parameters are optimized.
        optimizer_conf: hydra config for the optimizer. It is instantiated
            with either all parameters or the computed param groups as its
            positional argument.
        options_conf: optional mapping from optimizer option name to a list of
            scheduler configs. Each config may restrict the parameters it
            applies to via ``param_names`` (globs) and/or ``module_cls_names``.
        param_group_modifiers_conf: optional list of hydra configs. Each one is
            instantiated and called as
            ``modifier(scheduler_cfgs=..., model=model)``; its return value
            replaces the scheduler configs.
        validate_param_groups: when True, check that the resulting param
            groups are disjoint and cover every model parameter.

    Returns:
        The wrapped optimizer, with schedulers when ``options_conf`` is given.
    """
    named_parameters = dict(model.named_parameters())
    all_parameter_names = set(named_parameters)
    module_cls_to_all_param_names = get_module_cls_to_param_names(model)

    # Without scheduled options there is a single param group and no schedulers.
    if not options_conf:
        optimizer = hydra.utils.instantiate(optimizer_conf, named_parameters.values())
        return OptimizerWrapper(optimizer)

    # Instantiate each option's scheduler configs. Tag every config with its
    # option, resolve which parameters it targets, then fill in the option's
    # default config.
    all_scheduler_cfgs: List[List[dict]] = []
    for option, cfg_list in hydra.utils.instantiate(options_conf).items():
        for cfg in cfg_list:
            cfg.option = option
            cfg.parameter_names = _unix_pattern_to_parameter_names(
                cfg, all_parameter_names, module_cls_to_all_param_names
            )
        set_default_parameters(cfg_list, all_parameter_names)
        all_scheduler_cfgs.append(cfg_list)

    # Optional user hooks that may rewrite the scheduler configs.
    for modifier_conf in param_group_modifiers_conf or []:
        modifier = hydra.utils.instantiate(modifier_conf)
        all_scheduler_cfgs = modifier(scheduler_cfgs=all_scheduler_cfgs, model=model)

    # One param group per combination of per-option configs that selects at
    # least one parameter.
    schedulers, param_groups = map_scheduler_cfgs_to_param_groups(
        all_scheduler_cfgs, named_parameters
    )
    if validate_param_groups:
        validate_param_group_params(param_groups, model)

    optimizer = hydra.utils.instantiate(optimizer_conf, param_groups)
    return OptimizerWrapper(optimizer, schedulers)
def construct_optimizers(model: nn.Module, optim_conf) -> Union[List[OptimizerWrapper], None]:
    """Build the list of optimizers for ``model`` from ``optim_conf``.

    Returns ``None`` when ``optim_conf`` is ``None``. Otherwise, returns a
    one-element list holding the ``OptimizerWrapper`` built from
    ``optim_conf.optimizer`` and ``optim_conf.options``, with param-group
    validation enabled.
    """
    if optim_conf is None:
        return None
    return [
        construct_optimizer(
            model,
            optim_conf.optimizer,
            optim_conf.options,
            validate_param_groups=True,
        )
    ]
|