Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from typing import Callable, Mapping, Any | |
| import torch | |
| import tree | |
| from allenact.utils.system import get_logger | |
| from torch import nn | |
| from .file_utils import f_join | |
def freeze_module(module: nn.Module | torch.Tensor) -> nn.Module:
    """
    Turn off gradient tracking for a module or for a single tensor.

    A module has ``requires_grad`` cleared on every parameter and is then
    switched to eval mode. A tensor only has its own ``requires_grad`` flag
    cleared.

    Returns:
        The object that was passed in, modified in place.
    """
    if not torch.is_tensor(module):
        for weight in module.parameters():
            weight.requires_grad = False
        module.eval()
        return module
    # Plain tensor: there are no parameters to walk and no train/eval mode.
    module.requires_grad = False
    return module
def unfreeze_module(module: nn.Module | torch.Tensor) -> nn.Module:
    """
    Turn on gradient tracking for a module or for a single tensor.

    A module has ``requires_grad`` set on every parameter and is then
    switched to train mode. A tensor only has its own ``requires_grad`` flag
    set.

    Returns:
        The object that was passed in, modified in place.
    """
    if not torch.is_tensor(module):
        for weight in module.parameters():
            weight.requires_grad = True
        module.train()
        return module
    # Plain tensor: there are no parameters to walk and no train/eval mode.
    module.requires_grad = True
    return module
def freeze_bn(module: nn.Module):
    """
    Freeze every batch-norm layer found inside ``module``.

    A layer is selected when its class name contains ``"BatchNorm"``. Each
    selected layer gets ``momentum = 0.0`` and is switched to eval mode;
    all other submodules are left untouched.

    Returns:
        The same ``module``, modified in place.
    """
    batch_norm_layers = (
        layer for layer in module.modules() if "BatchNorm" in type(layer).__name__
    )
    for layer in batch_norm_layers:
        layer.momentum = 0.0
        layer.eval()
    return module
def replace_submodules(
    root_module: nn.Module,
    predicate: Callable[[nn.Module], bool],
    func: Callable[[nn.Module], nn.Module],
) -> nn.Module:
    """
    Swap every submodule selected by ``predicate`` for ``func(submodule)``.

    predicate: Return true if the module is to be replaced.
    func: Return new module to use.

    If ``root_module`` itself satisfies ``predicate``, ``func(root_module)``
    is returned directly. Otherwise the replacement happens in place and
    ``root_module`` is returned.
    """
    if predicate(root_module):
        return func(root_module)

    def _matching_paths():
        # Dotted module names, split into components, for every match.
        return [
            name.split(".")
            for name, sub in root_module.named_modules(remove_duplicate=True)
            if predicate(sub)
        ]

    # The paths are collected up front, before any submodule is replaced.
    for *ancestors, leaf in _matching_paths():
        if ancestors:
            owner = root_module.get_submodule(".".join(ancestors))
        else:
            owner = root_module
        if isinstance(owner, nn.Sequential):
            # Children of a Sequential are addressed by integer position.
            position = int(leaf)
            owner[position] = func(owner[position])
        else:
            setattr(owner, leaf, func(getattr(owner, leaf)))

    # verify that no matching submodule is left
    assert len(_matching_paths()) == 0
    return root_module
def bn_to_gn(
    module: nn.Module,
    group_ratio: int = 16,
    device: str | int | torch.device | None = None,
):
    """
    Replace every ``BatchNorm1d`` / ``BatchNorm2d`` in ``module`` with a
    ``GroupNorm`` layer.

    Args:
        module: root module to convert; handed to ``replace_submodules``.
        group_ratio: channels per group. The number of groups is
            ``num_features // group_ratio``, clamped to at least 1.
        device: passed to the ``nn.GroupNorm`` constructor.

    Returns:
        Whatever ``replace_submodules`` returns for ``module``.

    The new ``GroupNorm`` layers are built only from the channel count, so
    nothing is copied over from the batch-norm layers they replace.
    """

    def _to_group_norm(bn: nn.Module) -> nn.GroupNorm:
        # Layers with fewer than `group_ratio` channels would otherwise ask
        # for zero groups, which the GroupNorm constructor cannot handle.
        # NOTE(review): GroupNorm also requires num_channels to be divisible
        # by num_groups; channel counts that are not will still be rejected.
        return nn.GroupNorm(
            num_groups=max(1, bn.num_features // group_ratio),
            num_channels=bn.num_features,
            device=device,
        )

    return replace_submodules(
        root_module=module,
        predicate=lambda x: isinstance(x, (nn.BatchNorm2d, nn.BatchNorm1d)),
        func=_to_group_norm,
    )
def torch_load(*fpath: str, map_location="cpu") -> dict:
    """
    Join the given path components with ``f_join`` and load the result with
    ``torch.load``.

    Args:
        *fpath: path components, combined by ``f_join``.
        map_location: forwarded to ``torch.load``.
    """
    # NOTE(review): torch.load unpickles the file contents; only use this on
    # checkpoints from a trusted source.
    joined_path = str(f_join(*fpath))
    return torch.load(joined_path, map_location=map_location)
def tree_value_at_path(obj, paths: tuple):
    """
    Follow a sequence of keys/indices into a nested structure.

    Args:
        obj: nested container supporting ``obj[key]`` at every level.
        paths: keys or indices applied in order; an empty sequence returns
            ``obj`` itself.

    Returns:
        The value reached after applying every entry of ``paths``.

    Raises:
        ValueError: if any lookup fails. The original exception is attached
            as ``__cause__``. The object shown in the message is the node at
            which the lookup failed, not the root.
    """
    try:
        for p in paths:
            obj = obj[p]
        return obj
    except Exception as e:
        raise ValueError(
            f"{e}\n\n-- Incorrect nested path {paths} for object: {obj}."
        ) from e
def implements_method(object, method: str):
    """
    Check whether ``object`` has a callable attribute named ``method``.

    Returns:
        True if the attribute exists and is callable, False otherwise.
    """
    if not hasattr(object, method):
        return False
    return callable(getattr(object, method))
def load_state_dict(
    objects,
    states,
    strip_prefix: str | None = None,
    strict: bool = True,
    filter_prefix: list[str] | str | None = None,
    verbose: bool = True,
):
    """
    Load state into every object of a (possibly nested) structure.

    ``objects`` is traversed with ``tree.map_structure_with_path``; for each
    leaf, the state found at the same path inside ``states`` is passed to
    the leaf's ``load_state_dict`` method.

    Args:
        objects: a single object or nested structure of objects, each of
            which must provide ``load_state_dict()``.
        states: structure holding the matching state for every object.
        strict: objects and states must match exactly. When False, objects
            whose path is absent from ``states`` are skipped.
        strip_prefix: only match the keys that have the prefix, and strip it
        filter_prefix: prefix, or list of prefixes, of keys to drop from the
            state before loading. Applied after ``strip_prefix``.
        verbose: when ``strict`` is False, log missing and unexpected keys
            at debug level.

    Returns:
        The ``load_state_dict`` results arranged like ``objects``; skipped
        objects yield ``None``.

    Raises:
        ValueError: if an object has no ``load_state_dict`` method, or if
            ``strict`` is True and its path does not exist in ``states``.
    """
    # Normalise the excluded prefixes once instead of once per key.
    if isinstance(filter_prefix, str):
        excluded_prefixes = (filter_prefix,)
    elif filter_prefix:
        excluded_prefixes = tuple(filter_prefix)
    else:
        excluded_prefixes = ()

    # Results that carry missing_keys / unexpected_keys, gathered per object.
    # Reading those attributes off the mapped structure itself fails when
    # `objects` is nested, when an object's load_state_dict returns None, or
    # when an object was skipped.
    reports = []

    def _load(paths, obj):
        if not implements_method(obj, "load_state_dict"):
            raise ValueError(
                f"Object {type(obj)} does not support load_state_dict() method"
            )
        try:
            state = tree_value_at_path(states, paths)
        except ValueError:  # paths do not exist in `states` structure
            if strict:
                raise
            return None
        if strip_prefix:
            assert isinstance(strip_prefix, str)
            state = {
                k[len(strip_prefix) :]: v
                for k, v in state.items()
                if k.startswith(strip_prefix)
            }
        if excluded_prefixes:
            state = {
                k: v for k, v in state.items() if not k.startswith(excluded_prefixes)
            }
        # Only nn.Module.load_state_dict is called with `strict`.
        if isinstance(obj, nn.Module):
            result = obj.load_state_dict(state, strict=strict)
        else:
            result = obj.load_state_dict(state)
        if hasattr(result, "missing_keys") and hasattr(result, "unexpected_keys"):
            reports.append(result)
        return result

    keys = tree.map_structure_with_path(_load, objects)
    if not strict and verbose:
        for report in reports:
            if report.missing_keys:
                get_logger().debug(
                    f'Missing key(s) in state_dict: {", ".join(report.missing_keys)}'
                )
            if report.unexpected_keys:
                get_logger().debug(
                    f'Unexpected key(s) in state dict: {", ".join(report.unexpected_keys)}'
                )
    return keys
def load_pl_state_dict(
    model: nn.Module,
    states: Mapping[str, Any] | str,
    strip_prefix: str | None = None,
    filter_prefix: list[str] | str | None = None,
    strict: bool = True,
    verbose: bool = True,
    **kwargs,
):
    """
    Load a checkpoint into ``model`` via ``load_state_dict``.

    NOTE(review): "pl" presumably stands for PyTorch Lightning, whose
    checkpoints keep the weights under a "state_dict" key; confirm.

    Args:
        model: module receiving the weights.
        states: checkpoint mapping, or a path string that is first loaded
            with ``torch_load``. If the mapping has a "state_dict" entry,
            that entry is used; otherwise the mapping itself is used.
        strip_prefix: key prefix to match and strip. A trailing "." is
            appended when missing; an empty string is left as is.
        filter_prefix: forwarded to ``load_state_dict``.
        strict: forwarded to ``load_state_dict``.
        verbose: forwarded to ``load_state_dict``.
        **kwargs: forwarded to ``torch_load`` when ``states`` is a path.

    Returns:
        The result of ``load_state_dict``.
    """
    if isinstance(states, str):
        states = torch_load(states, **kwargs)
    # The truthiness check keeps an empty prefix from raising IndexError.
    if isinstance(strip_prefix, str) and strip_prefix and not strip_prefix.endswith("."):
        strip_prefix += "."
    return load_state_dict(
        model,
        states.get("state_dict", states),
        strip_prefix=strip_prefix,
        strict=strict,
        filter_prefix=filter_prefix,
        verbose=verbose,
    )