| | |
| | |
| |
|
| | |
| | |
| |
|
| | import fnmatch |
| | import inspect |
| | import itertools |
| | import logging |
| | import types |
| | from typing import ( |
| | Any, |
| | Callable, |
| | Dict, |
| | Iterable, |
| | List, |
| | Mapping, |
| | Optional, |
| | Set, |
| | Tuple, |
| | Type, |
| | Union, |
| | ) |
| |
|
| | import hydra |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from omegaconf import DictConfig |
| | from torch import Tensor |
| |
|
| |
|
class Optimizer:
    """Pair an optimizer with schedulers that drive its per-param-group options.

    Attributes:
        optimizer: The wrapped optimizer. It must expose ``defaults``,
            ``param_groups``, ``step`` and ``zero_grad``.
        schedulers: ``None``, or a sequence with one dict per param group (indexed
            like ``optimizer.param_groups``) mapping an option name such as "lr"
            to the scheduler callable which produces the value for that option.
    """

    def __init__(self, optimizer, schedulers=None) -> None:
        self.optimizer = optimizer
        self.schedulers = schedulers
        self._validate_optimizer_schedulers()
        # Write the schedulers' initial values into the param groups right away.
        self.step_schedulers(0.0, 0)

    def _validate_optimizer_schedulers(self):
        """Check that every scheduled option is an option of the optimizer.

        Raises:
            AssertionError: If a scheduler targets an option which is not a key
                of ``optimizer.defaults``.
        """
        if self.schedulers is None:
            return
        for set_of_schedulers in self.schedulers:
            for option in set_of_schedulers:
                if option not in self.optimizer.defaults:
                    raise AssertionError(
                        "Optimizer option "
                        f"{option} not found in {self.optimizer}. Valid options are "
                        f"{self.optimizer.defaults.keys()}"
                    )

    @staticmethod
    def _accepts_step(scheduler) -> bool:
        """Return True if the scheduler has to be called with ``step`` and ``where``."""
        if "step" in inspect.signature(scheduler.__call__).parameters:
            return True
        # A wrapper exposing the wrapped scheduler as ``.scheduler`` (for example
        # ValueScaler in this file) hides the real signature behind
        # *args/**kwargs, so the wrapped scheduler's signature decides.
        return (
            hasattr(scheduler, "scheduler")
            and "step" in inspect.signature(scheduler.scheduler.__call__).parameters
        )

    def step_schedulers(self, where: float, step: int) -> None:
        """Evaluate every scheduler and store its value in the matching param group.

        Args:
            where: Value handed to each scheduler as ``where``.
            step: Value handed as ``step`` to the schedulers which accept it.
        """
        if self.schedulers is None:
            return
        for i, param_group in enumerate(self.optimizer.param_groups):
            for option, scheduler in self.schedulers[i].items():
                if self._accepts_step(scheduler):
                    new_value = scheduler(step=step, where=where)
                else:
                    new_value = scheduler(where)
                param_group[option] = new_value

    def step(self, where, step, closure=None):
        """Update the scheduled options, then run one step of the wrapped optimizer.

        Returns:
            Whatever ``optimizer.step(closure)`` returns.
        """
        self.step_schedulers(where, step)
        return self.optimizer.step(closure)

    def zero_grad(self, *args, **kwargs):
        """Forward to ``optimizer.zero_grad`` and return its result."""
        return self.optimizer.zero_grad(*args, **kwargs)
| |
|
| |
|
def set_default_parameters(
    scheduler_cfgs: List["DictConfig"], all_parameter_names: Set[str]
) -> None:
    """Set up the "default" scheduler with the right parameters.

    The list is modified in place: the config whose ``parameter_names`` is None
    receives every parameter name not claimed by another config. If no config is
    a default, a scheduler-less ``{"parameter_names": ...}`` entry holding the
    unclaimed names is appended instead.

    Args:
        scheduler_cfgs: A list of scheduler configs, where each scheduler also
            specifies which parameters it applies to, based on the names of parameters
            or the class of the modules. At most one scheduler is allowed to skip this
            specification, which is used as a "default" specification for any remaining
            parameters.
        all_parameter_names: Names of all the parameters to consider.

    Raises:
        AssertionError: If more than one config has ``parameter_names`` set to None.
    """
    explicit_constraints = [
        scheduler_cfg.parameter_names
        for scheduler_cfg in scheduler_cfgs
        if scheduler_cfg.parameter_names is not None
    ]
    # Always a fresh set, and works for any iterable of names, with or without
    # explicit constraints.
    default_params = set(all_parameter_names).difference(*explicit_constraints)
    default_count = 0
    for scheduler_cfg in scheduler_cfgs:
        if scheduler_cfg.parameter_names is None:
            scheduler_cfg.parameter_names = default_params
            default_count += 1
    if default_count > 1:
        raise AssertionError("Only one scheduler per option can be default")
    if default_count == 0:
        # No default scheduler was given: add a default entry which carries the
        # remaining parameter names but no scheduler for this option.
        scheduler_cfgs.append({"parameter_names": default_params})
| |
|
| |
|
def name_constraints_to_parameters(
    param_constraints: List[Set[str]], named_parameters: Dict[str, Tensor]
) -> List[torch.nn.Parameter]:
    """Return parameters which match the intersection of parameter constraints.

    Note that this returns the parameters themselves, not their names.

    Args:
        param_constraints: A list, with each element being a set of allowed parameters.
        named_parameters: Mapping from a parameter name to the parameter itself.

    Returns:
        A list containing the parameters which overlap with _each_ constraint set from
        param_constraints, in the iteration order of named_parameters. With no
        constraints at all nothing is excluded, so every parameter is returned.
    """
    if not param_constraints:
        # The intersection of zero constraint sets places no restriction.
        return list(named_parameters.values())
    first, *others = param_constraints
    # Building the set from the first constraint lets any iterable of names be used.
    matching_names = set(first).intersection(*others)
    return [value for name, value in named_parameters.items() if name in matching_names]
| |
|
| |
|
def map_scheduler_cfgs_to_param_groups(
    all_scheduler_cfgs: Iterable[List[Dict]],
    named_parameters: Dict[str, Tensor],
) -> Tuple[List[Dict[Any, Any]], List[Dict[str, List[torch.nn.Parameter]]]]:
    """Produce parameter groups corresponding to all the scheduler configs.

    Takes all the scheduler configs, each of which applies to a specific optimizer
    option (like "lr" or "weight_decay") and has a set of parameter names which it
    applies to, and produces a final set of param groups where each param group
    covers all the options which apply to a particular set of parameters.

    Args:
        all_scheduler_cfgs: All the scheduler configs covering every option.
        named_parameters: Mapping from a parameter name to the parameter itself.
    Returns:
        Tuple of lists of schedulers and param_groups, where schedulers[i]
        applies to param_groups[i].
    """
    schedulers = []
    param_groups = []
    # Every element of the product picks exactly one config per option; the
    # parameters shared by all picked configs form one candidate param group.
    for scheduler_cfgs in itertools.product(*all_scheduler_cfgs):
        param_constraints = [
            scheduler_cfg["parameter_names"] for scheduler_cfg in scheduler_cfgs
        ]
        matching_parameters = name_constraints_to_parameters(
            param_constraints, named_parameters
        )
        if not matching_parameters:
            # This combination of configs has no parameter in common.
            continue
        # Default configs carry no "option" key and contribute no scheduler.
        schedulers_for_group = {
            scheduler_cfg["option"]: scheduler_cfg["scheduler"]
            for scheduler_cfg in scheduler_cfgs
            if "option" in scheduler_cfg
        }
        schedulers.append(schedulers_for_group)
        param_groups.append({"params": matching_parameters})
    return schedulers, param_groups
| |
|
| |
|
def validate_param_group_params(param_groups: List[Dict], model: nn.Module):
    """Check that the param groups are non-overlapping and cover all the parameters.

    Args:
        param_groups: List of all param groups
        model: Model to validate against. The check ensures that all the model
            parameters are part of param_groups

    Raises:
        AssertionError: If a group repeats a parameter, if two groups share a
            parameter, or if the groups together do not equal the set of
            parameters yielded by ``model.named_parameters()``.
    """
    for pg in param_groups:
        # No param should be repeated within a group.
        if len(pg["params"]) != len(set(pg["params"])):
            raise AssertionError(
                "Scheduler generated param_groups should not repeat a parameter"
            )
    parameters = [set(param_group["params"]) for param_group in param_groups]
    model_parameters = {parameter for _, parameter in model.named_parameters()}
    # Disjointness is symmetric, so each unordered pair is checked once.
    for p1, p2 in itertools.combinations(parameters, 2):
        if not p1.isdisjoint(p2):
            raise AssertionError("Scheduler generated param_groups should be disjoint")
    # Starting from an empty set keeps the union defined for zero param groups.
    covered_parameters = set().union(*parameters)
    if covered_parameters != model_parameters:
        raise AssertionError(
            "Scheduler generated param_groups must include all parameters of the model."
            f" Found {len(covered_parameters)} params whereas model has"
            f" {len(model_parameters)} params"
        )
| |
|
| |
|
def unix_module_cls_pattern_to_parameter_names(
    filter_module_cls_names: Optional[List[str]],
    module_cls_to_param_names: Dict[Type, Set[str]],
) -> Set[str]:
    """Returns param names which pass the filters specified in filter_module_cls_names.

    Each entry is resolved to a class with ``hydra.utils.get_class`` and looked up
    as an exact key of module_cls_to_param_names; no wildcard matching is done.

    Args:
        filter_module_cls_names: A list of filter strings containing class names, like
            ["torch.nn.LayerNorm", "torch.nn.BatchNorm2d"]. None or an empty list
            selects nothing.
        module_cls_to_param_names: Mapping from module classes to the parameter names
            they contain. See `get_module_cls_to_param_names`.

    Returns:
        A new set with the names of the parameters owned by any of the classes.

    Raises:
        AssertionError: If a class is not a key of module_cls_to_param_names or
            maps to no parameter names.
    """
    if filter_module_cls_names is None:
        return set()
    allowed_parameter_names = []
    for module_cls_name in filter_module_cls_names:
        module_cls = hydra.utils.get_class(module_cls_name)
        if module_cls not in module_cls_to_param_names:
            raise AssertionError(
                f"module_cls_name {module_cls_name} does not "
                "match any classes in the model"
            )
        matching_parameters = module_cls_to_param_names[module_cls]
        if len(matching_parameters) == 0:
            raise AssertionError(
                f"module_cls_name {module_cls_name} does not contain any parameters in the model"
            )
        logging.info(
            "Matches for module_cls_name [%s]: %s ",
            module_cls_name,
            matching_parameters,
        )
        allowed_parameter_names.append(matching_parameters)
    # Starting from an empty set keeps the union defined for an empty filter list.
    return set().union(*allowed_parameter_names)
| |
|
| |
|
def unix_param_pattern_to_parameter_names(
    filter_param_names: Optional[List[str]],
    parameter_names: Iterable[str],
) -> Set[str]:
    """Returns param names which pass the filters specified in filter_param_names.

    Args:
        filter_param_names: A list of unix-style filter strings with optional
            wildcards, like ["block.2.*", "block.2.linear.weight"]. None or an
            empty list selects nothing.
        parameter_names: The candidate parameter names to match against. It is
            iterated once per pattern, so it must be re-iterable (a set, a list,
            or a mapping keyed by name).

    Returns:
        A new set with every name matched by at least one pattern.

    Raises:
        AssertionError: If a pattern matches none of the names.
    """
    if filter_param_names is None:
        return set()
    allowed_parameter_names = []
    for param_name in filter_param_names:
        matching_parameters = set(fnmatch.filter(parameter_names, param_name))
        if len(matching_parameters) < 1:
            raise AssertionError(
                f"param_name {param_name} does not match any parameters in the model"
            )
        logging.info(
            "Matches for param_name [%s]: %s", param_name, matching_parameters
        )
        allowed_parameter_names.append(matching_parameters)
    # Starting from an empty set keeps the union defined for an empty filter list.
    return set().union(*allowed_parameter_names)
| |
|
| |
|
def _unix_pattern_to_parameter_names(
    scheduler_cfg: "DictConfig",
    parameter_names: Set[str],
    module_cls_to_param_names: Dict[Type, Set[str]],
) -> Optional[Set[str]]:
    """Returns param names which pass the filters specified in scheduler_cfg.

    Args:
        scheduler_cfg: The config for the scheduler. Its optional "param_names"
            and "module_cls_names" entries are the filters.
        parameter_names: The set of all parameter names which will be filtered
        module_cls_to_param_names: Mapping from module classes to the parameter
            names they contain, used to resolve "module_cls_names".

    Returns:
        None when the config has neither filter key; otherwise the union of the
        names selected by the name patterns and by the module classes.
    """
    if "param_names" not in scheduler_cfg and "module_cls_names" not in scheduler_cfg:
        # No filter given: the caller decides what such a config applies to.
        return None
    names_from_patterns = unix_param_pattern_to_parameter_names(
        scheduler_cfg.get("param_names"), parameter_names
    )
    names_from_module_classes = unix_module_cls_pattern_to_parameter_names(
        scheduler_cfg.get("module_cls_names"), module_cls_to_param_names
    )
    return names_from_patterns.union(names_from_module_classes)
| |
|
| |
|
def get_module_cls_to_param_names(
    model: nn.Module, param_allowlist: Optional[Set[str]] = None
) -> Dict[Type, Set[str]]:
    """Produce a mapping from all the module classes to the names of params they own.

    Only counts a parameter as part of the immediate parent module, i.e. recursive
    parents do not count.

    Args:
        model: Model to iterate over
        param_allowlist: If specified, only these param names will be processed

    Returns:
        A dict keyed by the class of every module yielded by
        ``model.named_modules()``; a class whose instances directly own no
        (allowlisted) parameter maps to an empty set.
    """
    module_cls_to_params = {}
    for module_name, module in model.named_modules():
        owned_names = module_cls_to_params.setdefault(type(module), set())
        # recurse=False restricts this to parameters held directly by the module.
        for param_name, _ in module.named_parameters(recurse=False):
            full_param_name = get_full_parameter_name(module_name, param_name)
            if param_allowlist is None or full_param_name in param_allowlist:
                owned_names.add(full_param_name)
    return module_cls_to_params
| |
|
| |
|
def construct_optimizer(
    model: torch.nn.Module,
    optimizer_conf: Any,
    options_conf: Optional[Mapping[str, List]] = None,
    param_group_modifiers_conf: Optional[List[Callable]] = None,
    param_allowlist: Optional[Set[str]] = None,
    validate_param_groups=True,
) -> Optimizer:
    """Build an `Optimizer` whose param groups follow the scheduler configs.

    Without options_conf the optimizer is instantiated on all allowlisted
    parameters and no schedulers are attached. Otherwise the scheduler configs of
    every option are resolved to parameter names, optionally rewritten by the
    param group modifiers, and turned into param groups with one scheduler dict
    per group.

    Args:
        model: Model whose parameters are optimized.
        optimizer_conf: Hydra config consisting a partial torch optimizer like SGD or
            ADAM, still missing the params argument which this function provides to
            produce the final optimizer
        options_conf: Hydra config mapping an optimizer option name (like "lr") to
            the list of scheduler configs for that option. Each config may carry
            "param_names" and/or "module_cls_names" filters.
        param_group_modifiers_conf: Optional user specified functions which can modify
            the final scheduler configs before the optimizer's param groups are built.
            Each entry is instantiated with hydra and called with the keyword
            arguments ``scheduler_cfgs`` and ``model``; its return value replaces
            the scheduler configs.
        param_allowlist: The parameters to optimize. Parameters which are not part of
            this allowlist will be skipped.
        validate_param_groups: If enabled, validates that the produced param_groups
            don't overlap and cover all the model parameters.

    Returns:
        The `Optimizer` wrapping the instantiated optimizer and its schedulers.
    """
    if param_allowlist is None:
        param_allowlist = {name for name, _ in model.named_parameters()}

    named_parameters = {
        name: param
        for name, param in model.named_parameters()
        if name in param_allowlist
    }

    if not options_conf:
        optimizer = hydra.utils.instantiate(optimizer_conf, named_parameters.values())
        return Optimizer(optimizer)

    # The keys of named_parameters are exactly the allowlisted parameter names.
    all_parameter_names = set(named_parameters)
    module_cls_to_all_param_names = get_module_cls_to_param_names(
        model, param_allowlist
    )

    scheduler_cfgs_per_option = hydra.utils.instantiate(options_conf)
    all_scheduler_cfgs = []
    for option, scheduler_cfgs in scheduler_cfgs_per_option.items():
        for config in scheduler_cfgs:
            config.option = option
            config.parameter_names = _unix_pattern_to_parameter_names(
                config, all_parameter_names, module_cls_to_all_param_names
            )
        # Hands the parameters not claimed by any filter to the default config.
        set_default_parameters(scheduler_cfgs, all_parameter_names)
        all_scheduler_cfgs.append(scheduler_cfgs)

    if param_group_modifiers_conf:
        for custom_param_modifier in param_group_modifiers_conf:
            custom_param_modifier = hydra.utils.instantiate(custom_param_modifier)
            all_scheduler_cfgs = custom_param_modifier(
                scheduler_cfgs=all_scheduler_cfgs, model=model
            )
    schedulers, param_groups = map_scheduler_cfgs_to_param_groups(
        all_scheduler_cfgs, named_parameters
    )
    if validate_param_groups:
        # NOTE(review): validation compares the groups against every parameter of
        # `model`, while the groups only hold allowlisted parameters, so a partial
        # param_allowlist looks incompatible with validate_param_groups=True.
        # Confirm the intended behavior.
        validate_param_group_params(param_groups, model)
    optimizer = hydra.utils.instantiate(optimizer_conf, param_groups)
    return Optimizer(optimizer, schedulers)
| |
|
| |
|
def get_full_parameter_name(module_name: str, param_name: str) -> str:
    """Join a module name and a parameter name into a dotted parameter name.

    Args:
        module_name: Dotted name of the owning module; "" means there is no
            prefix to add.
        param_name: Name of the parameter inside that module.

    Returns:
        ``param_name`` when module_name is empty, else "module_name.param_name".
    """
    if module_name == "":
        return param_name
    return f"{module_name}.{param_name}"
| |
|
| |
|
class GradientClipper:
    """Callable which clips the gradient norm of a model's parameters.

    Attributes:
        max_norm: Maximum gradient norm as a float, or None to disable clipping.
        norm_type: Norm type forwarded to ``nn.utils.clip_grad_norm_``.
    """

    def __init__(self, max_norm: Optional[float] = 1.0, norm_type: int = 2):
        """
        Args:
            max_norm: Maximum gradient norm (int or float), or None for no clipping.
            norm_type: Type of the norm used for clipping.

        Raises:
            AssertionError: If max_norm is neither a number nor None.
        """
        if max_norm is not None and not isinstance(max_norm, (int, float)):
            raise AssertionError("max_norm must be an int, a float or None")
        self.max_norm = max_norm if max_norm is None else float(max_norm)
        self.norm_type = norm_type

    def __call__(self, model: nn.Module):
        """Clip the gradients of ``model.parameters()`` in place; returns None."""
        if self.max_norm is None:
            return

        nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=self.max_norm, norm_type=self.norm_type
        )
| |
|
| |
|
class ValueScaler:
    """Scheduler wrapper which multiplies the wrapped scheduler's value by a constant.

    Attributes:
        scheduler: The wrapped scheduler callable.
        mult_val: Factor applied to every value the scheduler returns.
    """

    def __init__(self, scheduler, mult_val: float):
        self.scheduler = scheduler
        self.mult_val = mult_val

    def __call__(self, *args, **kwargs):
        """Call the wrapped scheduler with the same arguments and scale its result."""
        val = self.scheduler(*args, **kwargs)
        return val * self.mult_val
| |
|
| |
|
def rgetattr(obj, rattrs: Optional[str] = None):
    """
    Like getattr(), but supports dotted notation for nested objects.
    rattrs is a str of form 'attr1.attr2', returns obj.attr1.attr2

    Args:
        obj: Object to start the lookup from.
        rattrs: Dotted attribute path; None returns obj itself.

    Raises:
        AttributeError: If an attribute along the path does not exist.
    """
    if rattrs is None:
        return obj
    for attr in rattrs.split("."):
        obj = getattr(obj, attr)
    return obj
| |
|
| |
|
def layer_decay_param_modifier(
    scheduler_cfgs: List[List[Dict]],
    model,
    layer_decay_value: float,
    layer_decay_min: Optional[float] = None,
    apply_to: Optional[str] = None,
    overrides: List[Dict] = (),
) -> List[List[Dict]]:
    """Split the "lr" scheduler configs per layer and scale them by a layer decay.

    Args
        - scheduler_cfgs: a list of omegaconf.ListConfigs.
            Each element in the list is a omegaconf.DictConfig with the following structure
            {
                "scheduler": <some fvcore scheduler>
                "option": <value> possible options are "lr", "weight_decay" etc.
                "parameter_names": Set of str indicating param names that this scheduler applies to
            }
            Configs without an "option" entry are passed through unchanged.
        - model: a model that implements a method `get_layer_id` that maps layer_name to an integer
            and a method get_num_layers.
            Alternatively, use apply_to argument to select a specific component of the model.
        - layer_decay_value: float, base of the decay. With n = get_num_layers() + 1,
            layer i is scaled by layer_decay_value ** (n - i).
        - layer_decay_min: min val for layer decay
        - apply_to: optional dotted attribute path selecting which component of the model to apply
            the layer decay modifier to. Only parameters whose name starts with it are decayed;
            None applies the modifier to the whole model.
        - overrides: to manually override lr for specific patterns. Is a list of dicts. Each dict
            has keys "pattern", "value". The first matching pattern wins.
    Returns
        - scheduler_configs: same structure as the input, elements can be modified
    """
    model = rgetattr(model, apply_to)
    num_layers = model.get_num_layers() + 1
    # layer_decays[num_layers] == layer_decay_value ** 0, i.e. no decay.
    layer_decays = [
        layer_decay_value ** (num_layers - i) for i in range(num_layers + 1)
    ]
    if layer_decay_min is not None:
        layer_decays = [max(val, layer_decay_min) for val in layer_decays]
    final_scheduler_cfgs = []

    for scheduler_cfg_group in scheduler_cfgs:
        curr_cfg_group = []

        for scheduler_cfg in scheduler_cfg_group:
            # Only "lr" configs are rewritten. `.get` also lets through the
            # default entries, which carry parameter names but no "option".
            if scheduler_cfg.get("option") != "lr":
                curr_cfg_group.append(scheduler_cfg)
                continue

            # Sorted so the resulting groups are built in a deterministic order.
            parameter_names = sorted(scheduler_cfg["parameter_names"])

            # Maps a layer id (or an override pattern) to the config of its group.
            layer_cfg_groups = {}
            for param_name in parameter_names:
                # Parameters outside the selected component keep the undecayed scale.
                layer_id = num_layers
                this_scale = layer_decays[layer_id]
                # With apply_to=None the whole model is selected, so every
                # parameter is in scope (str.startswith(None) would raise).
                if apply_to is None or param_name.startswith(apply_to):
                    layer_id = model.get_layer_id(param_name)
                    this_scale = layer_decays[layer_id]
                    for override in overrides:
                        if fnmatch.fnmatchcase(param_name, override["pattern"]):
                            this_scale = float(override["value"])
                            # Overridden params are grouped per pattern.
                            layer_id = override["pattern"]
                            break

                if layer_id not in layer_cfg_groups:
                    layer_cfg_groups[layer_id] = {
                        "option": scheduler_cfg["option"],
                        "scheduler": ValueScaler(
                            scheduler_cfg["scheduler"], this_scale
                        ),
                        "parameter_names": {param_name},
                    }
                else:
                    layer_cfg_groups[layer_id]["parameter_names"].add(param_name)

            curr_cfg_group.extend(layer_cfg_groups.values())

        final_scheduler_cfgs.append(curr_cfg_group)
    return final_scheduler_cfgs
| |
|