UVD / uvd /utils /module_utils.py
ryanhoangt's picture
Upload folder using huggingface_hub
c456c14 verified
from __future__ import annotations
from typing import Callable, Mapping, Any
import torch
import tree
from allenact.utils.system import get_logger
from torch import nn
from .file_utils import f_join
def freeze_module(module: nn.Module | torch.Tensor) -> nn.Module:
    """Turn off gradient tracking for a tensor or for all parameters of a module.

    A tensor only has its ``requires_grad`` flag cleared. A module has the flag
    cleared on every parameter and is then switched to evaluation mode with
    ``eval()``.

    Args:
        module: The tensor or module to freeze; it is modified in place.

    Returns:
        The same object that was passed in.
    """
    if not torch.is_tensor(module):
        for weight in module.parameters():
            weight.requires_grad_(False)
        module.eval()
    else:
        module.requires_grad = False
    return module
def unfreeze_module(module: nn.Module | torch.Tensor) -> nn.Module:
    """Turn on gradient tracking for a tensor or for all parameters of a module.

    A tensor only has its ``requires_grad`` flag set. A module has the flag set
    on every parameter and is then switched to training mode with ``train()``.

    Args:
        module: The tensor or module to unfreeze; it is modified in place.

    Returns:
        The same object that was passed in.
    """
    if not torch.is_tensor(module):
        for weight in module.parameters():
            weight.requires_grad_(True)
        module.train()
    else:
        module.requires_grad = True
    return module
def freeze_bn(module: nn.Module):
    """Put every batch-norm layer inside ``module`` into evaluation mode.

    A layer is selected when its class name contains ``"BatchNorm"``. Each
    selected layer has ``momentum`` set to ``0.0`` and ``eval()`` called on it.
    All other submodules are left as they are.

    Args:
        module: Module whose submodules (itself included) are scanned.

    Returns:
        The same ``module``, modified in place.
    """
    batch_norm_layers = (
        layer for layer in module.modules() if "BatchNorm" in type(layer).__name__
    )
    for layer in batch_norm_layers:
        layer.eval()
        layer.momentum = 0.0
    return module
def replace_submodules(
    root_module: nn.Module,
    predicate: Callable[[nn.Module], bool],
    func: Callable[[nn.Module], nn.Module],
) -> nn.Module:
    """Replace every submodule selected by ``predicate`` with ``func(submodule)``.

    Args:
        root_module: Module whose tree is searched. It is modified in place
            unless it matches ``predicate`` itself.
        predicate: Return true if the module is to be replaced.
        func: Return new module to use.

    Returns:
        ``func(root_module)`` when the root itself matches ``predicate``;
        otherwise ``root_module`` with the matching submodules swapped out.

    Raises:
        AssertionError: If a module matching ``predicate`` is still present
            after the replacement pass.
    """
    if predicate(root_module):
        return func(root_module)

    # Snapshot the matching names before mutating the tree.
    matched_names = [
        name
        for name, mod in root_module.named_modules(remove_duplicate=True)
        if predicate(mod)
    ]
    for name in matched_names:
        parent_path, _, child_name = name.rpartition(".")
        parent_module = (
            root_module.get_submodule(parent_path) if parent_path else root_module
        )
        # Resolve the child by its registered name instead of int(child_name):
        # nn.Sequential children may carry non-numeric names (when built from
        # an OrderedDict), and attribute access works for both cases.
        src_module = getattr(parent_module, child_name)
        setattr(parent_module, child_name, func(src_module))

    # Verify that every matching module has been replaced.
    remaining = [
        name
        for name, mod in root_module.named_modules(remove_duplicate=True)
        if predicate(mod)
    ]
    assert not remaining, f"Modules still matching predicate: {remaining}"
    return root_module
def bn_to_gn(
    module: nn.Module,
    group_ratio: int = 16,
    device: str | int | torch.device | None = None,
):
    """Swap every ``BatchNorm1d``/``BatchNorm2d`` in ``module`` for ``GroupNorm``.

    Args:
        module: Module to convert; forwarded to ``replace_submodules``.
        group_ratio: Channels per group. The group count of each new layer is
            ``num_features // group_ratio``, but never less than one.
        device: Passed to the ``nn.GroupNorm`` constructor.

    Returns:
        The result of ``replace_submodules`` for ``module``.
    """

    def _to_group_norm(bn: nn.Module) -> nn.GroupNorm:
        # Layers with fewer channels than ``group_ratio`` would otherwise get
        # zero groups, which nn.GroupNorm cannot be constructed with.
        return nn.GroupNorm(
            num_groups=max(1, bn.num_features // group_ratio),
            num_channels=bn.num_features,
            device=device,
        )

    return replace_submodules(
        root_module=module,
        predicate=lambda x: isinstance(x, (nn.BatchNorm2d, nn.BatchNorm1d)),
        func=_to_group_norm,
    )
def torch_load(*fpath: str, map_location="cpu") -> dict:
    """Join the given path components and load the file with ``torch.load``.

    Args:
        *fpath: Path components, combined with ``f_join``.
        map_location: Forwarded to ``torch.load``; defaults to ``"cpu"``.

    Returns:
        Whatever ``torch.load`` returns for the joined path.
    """
    joined_path = f_join(*fpath)
    return torch.load(str(joined_path), map_location=map_location)
def tree_value_at_path(obj, paths: tuple):
    """Index into a nested container by following ``paths`` one key at a time.

    Args:
        obj: The container to index into.
        paths: Sequence of keys/indices applied in order via ``obj[p]``. An
            empty sequence returns ``obj`` itself.

    Returns:
        The value reached after applying every element of ``paths``.

    Raises:
        ValueError: If any indexing step fails. The original exception is
            attached as ``__cause__``.
    """
    try:
        for p in paths:
            obj = obj[p]
        return obj
    except Exception as e:
        # Any lookup failure (KeyError, IndexError, TypeError, ...) is reported
        # uniformly as ValueError; chain it so the root cause stays visible.
        raise ValueError(
            f"{e}\n\n-- Incorrect nested path {paths} for object: {obj}."
        ) from e
def implements_method(object, method: str):
    """Check whether ``object`` exposes a callable attribute named ``method``.

    Returns:
        True if the attribute exists and is callable, False otherwise.
    """
    attribute = getattr(object, method, None)
    return callable(attribute)
def load_state_dict(
    objects,
    states,
    strip_prefix: str | None = None,
    strict: bool = True,
    filter_prefix: list[str] | str | None = None,
    verbose: bool = True,
):
    """Load ``states`` into every leaf of the (possibly nested) ``objects``.

    ``objects`` is walked with ``tree.map_structure_with_path``. For each leaf,
    the entry of ``states`` found at the same path is loaded through the leaf's
    own ``load_state_dict`` method.

    Args:
        objects: A single object or a nested structure of objects, each of
            which must implement ``load_state_dict``.
        states: State(s) indexed with the same paths as ``objects``.
        strip_prefix: only match the keys that have the prefix, and strip it
        strict: objects and states must match exactly
        filter_prefix: A prefix or list of prefixes; keys starting with any of
            them are dropped before loading.
        verbose: When ``strict`` is false, log missing/unexpected keys at
            debug level.

    Returns:
        A structure shaped like ``objects`` holding whatever each leaf's
        ``load_state_dict`` returned (``None`` for leaves that were skipped
        because their path is absent from ``states`` in non-strict mode).

    Raises:
        ValueError: If a leaf has no ``load_state_dict`` method, or if
            ``strict`` is true and a leaf's path is missing from ``states``.
    """
    # Normalize once so the per-key test is a single startswith() call.
    if isinstance(filter_prefix, str):
        excluded_prefixes = (filter_prefix,)
    else:
        excluded_prefixes = tuple(filter_prefix or ())

    # Per-leaf results, collected for logging. The mapped return value cannot
    # be inspected directly: it may be a nested structure, or ``None``.
    results = []

    def _load(paths, obj):
        if not implements_method(obj, "load_state_dict"):
            raise ValueError(
                f"Object {type(obj)} does not support load_state_dict() method"
            )
        try:
            state = tree_value_at_path(states, paths)
        except ValueError:  # paths do not exist in `states` structure
            if strict:
                raise
            return None
        if strip_prefix:
            assert isinstance(strip_prefix, str)
            state = {
                k[len(strip_prefix) :]: v
                for k, v in state.items()
                if k.startswith(strip_prefix)
            }
        if excluded_prefixes:
            state = {
                k: v for k, v in state.items() if not k.startswith(excluded_prefixes)
            }
        if isinstance(obj, nn.Module):
            result = obj.load_state_dict(state, strict=strict)
        else:
            result = obj.load_state_dict(state)
        results.append(result)
        return result

    keys = tree.map_structure_with_path(_load, objects)
    if not strict and verbose:
        for result in results:
            # Only some loaders report key mismatches; others return None.
            missing_keys = getattr(result, "missing_keys", None)
            unexpected_keys = getattr(result, "unexpected_keys", None)
            if missing_keys:
                get_logger().debug(
                    f'Missing key(s) in state_dict: {", ".join(missing_keys)}'
                )
            if unexpected_keys:
                get_logger().debug(
                    f'Unexpected key(s) in state dict: {", ".join(unexpected_keys)}'
                )
    return keys
def load_pl_state_dict(
    model: nn.Module,
    states: Mapping[str, Any] | str,
    strip_prefix: str | None = None,
    filter_prefix: list[str] | str | None = None,
    strict: bool = True,
    verbose: bool = True,
    **kwargs,
):
    """Load a checkpoint into ``model``, unwrapping a ``"state_dict"`` entry.

    NOTE(review): "pl" presumably refers to the PyTorch Lightning checkpoint
    layout -- confirm against callers.

    Args:
        model: Module that receives the weights.
        states: Either an already-loaded mapping or a path, which is read with
            ``torch_load``.
        strip_prefix: Key prefix to match and strip. A trailing ``"."`` is
            appended when missing; an empty string is left unchanged.
        filter_prefix: Prefix(es) of keys to drop; forwarded to
            ``load_state_dict``.
        strict: Forwarded to ``load_state_dict``.
        verbose: Forwarded to ``load_state_dict``.
        **kwargs: Forwarded to ``torch_load`` when ``states`` is a path.

    Returns:
        The return value of ``load_state_dict``.
    """
    if isinstance(states, str):
        states = torch_load(states, **kwargs)
    # endswith() is safe on an empty string, unlike indexing with [-1].
    if (
        isinstance(strip_prefix, str)
        and strip_prefix
        and not strip_prefix.endswith(".")
    ):
        strip_prefix += "."
    return load_state_dict(
        model,
        # Use the nested "state_dict" entry when present, else the mapping itself.
        states.get("state_dict", states),
        strip_prefix=strip_prefix,
        strict=strict,
        filter_prefix=filter_prefix,
        verbose=verbose,
    )