Spaces:
Runtime error
Runtime error
| from collections import defaultdict | |
| import torch.nn as nn | |
| from typing import Any | |
| from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable | |
| from termcolor import colored | |
def get_missing_parameters_message(keys: List[str]) -> str:
    """
    Build a log-ready report of parameter names (keys) that exist in the
    model but were absent from a checkpoint.

    Keys are grouped by their shared prefix, and each group is rendered on
    its own line, colored blue.

    Args:
        keys (list[str]): List of keys that were not found in the checkpoint.

    Returns:
        str: a header line followed by one line per key group.
    """
    header = "Some model parameters or buffers are not found in the checkpoint:\n"
    lines = []
    for prefix, suffixes in _group_checkpoint_keys(keys).items():
        lines.append(" " + colored(prefix + _group_to_str(suffixes), "blue"))
    return header + "\n".join(lines)
def get_unexpected_parameters_message(keys: List[str]) -> str:
    """
    Build a log-ready report of parameter names (keys) that exist in the
    checkpoint but are not used by the model.

    Keys are grouped by their shared prefix, and each group is rendered on
    its own line, colored magenta.

    Args:
        keys (list[str]): List of keys that were not found in the model.

    Returns:
        str: a header line followed by one line per key group.
    """
    header = "The checkpoint state_dict contains keys that are not used by the model:\n"
    lines = []
    for prefix, suffixes in _group_checkpoint_keys(keys).items():
        lines.append(" " + colored(prefix + _group_to_str(suffixes), "magenta"))
    return header + "\n".join(lines)
def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
    """
    Remove `prefix` from the keys of `state_dict`, modifying it in place.

    Nothing is changed unless every non-empty key starts with `prefix`.
    Entries are re-inserted in sorted-key order. If the state-dict carries a
    `_metadata` mapping, each of its non-empty keys is shortened by
    `len(prefix)` characters as well.

    Args:
        state_dict (OrderedDict): a state-dict to be loaded to the model.
        prefix (str): prefix.
    """
    ordered_keys = sorted(state_dict.keys())
    # Bail out as soon as one non-empty key lacks the prefix.
    if any(len(k) > 0 and not k.startswith(prefix) for k in ordered_keys):
        return

    cut = len(prefix)
    for old_key in ordered_keys:
        state_dict[old_key[cut:]] = state_dict.pop(old_key)

    # also strip the prefix in metadata, if any..
    if not hasattr(state_dict, "_metadata"):
        return
    metadata = state_dict._metadata  # pyre-ignore
    for old_key in tuple(metadata.keys()):
        # for the metadata dict, the key can be:
        # '': for the DDP module, which we want to remove.
        # 'module': for the actual model.
        # 'module.xx.xx': for the rest.
        # Slicing is unconditional on purpose: 'module' is shorter than
        # 'module.' and collapses to '', replacing the DDP entry.
        if len(old_key) > 0:
            metadata[old_key[cut:]] = metadata.pop(old_key)
def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
    """
    Bucket keys by common prefix, where the prefix of a key is everything
    before its final ".".

    A key without any "." becomes its own group with no suffixes.

    Args:
        keys (list[str]): list of parameter names, i.e. keys in the model
            checkpoint dict.

    Returns:
        dict[list]: maps each prefix to the list of suffixes seen for it, in
            input order.
    """
    grouped = defaultdict(list)
    for name in keys:
        stem, dot, leaf = name.rpartition(".")
        if dot:
            grouped[stem].append(leaf)
        else:
            # No separator: register the bare key as a group, keeping any
            # suffixes already collected under the same name.
            grouped.setdefault(name, [])
    return grouped
def _group_to_str(group: List[str]) -> str:
    """
    Render a group of parameter name suffixes as a loggable string.

    An empty group renders as "", a single suffix as ".name", and several
    suffixes as ".{a, b, ...}".

    Args:
        group (list[str]): list of parameter name suffixes.

    Returns:
        str: formatted string.
    """
    if not group:
        return ""
    if len(group) == 1:
        rendered = group[0]
    else:
        rendered = "{" + ", ".join(group) + "}"
    return "." + rendered
def _named_modules_with_dup(
    model: nn.Module, prefix: str = ""
) -> Iterable[Tuple[str, nn.Module]]:
    """
    The same as `model.named_modules()`, except that a module registered
    under more than one name is yielded once per name.

    Yields `(name, module)` pairs starting with `(prefix, model)` and then
    descending recursively into each child; children that are `None` are
    skipped. Nested names are joined with ".".
    """
    yield prefix, model
    for child_name, child in model._modules.items():  # pyre-ignore
        if child is not None:
            # Only insert the "." separator when there is a prefix to join.
            child_prefix = f"{prefix}.{child_name}" if prefix else child_name
            yield from _named_modules_with_dup(child, child_prefix)