# genmo/utils/net_utils.py
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from pytorch_lightning.utilities.memory import recursive_detach
from scipy.ndimage._filters import _gaussian_kernel1d
from genmo.utils.pylogger import Log

def load_pretrained_model(model, ckpt_path):
    """
    Load a checkpoint into `model`. If the model implements its own
    `load_pretrained_model`, delegate to it; otherwise load the state dict
    strictly (mapped to CPU first).
    """
assert Path(ckpt_path).exists()
# use model's own load_pretrained_model method
if hasattr(model, "load_pretrained_model"):
ckpt = model.load_pretrained_model(ckpt_path)
else:
Log.info(f"Loading ckpt: {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(ckpt, strict=True)
return ckpt
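
# A minimal usage sketch (hypothetical `MyNet` and checkpoint path):
#
#   model = MyNet()
#   ckpt = load_pretrained_model(model, "outputs/exp/checkpoints/last.ckpt")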

def find_last_ckpt_path(dirpath):
    """
    Assume checkpoints are named `e{epoch}*` or `last*`, following the
    pytorch-lightning convention: prefer `last.ckpt`, else the checkpoint
    that sorts last.
    """
assert dirpath is not None
dirpath = Path(dirpath)
assert dirpath.exists()
# Priority 1: last.ckpt
auto_last_ckpt_path = dirpath / "last.ckpt"
if auto_last_ckpt_path.exists():
return auto_last_ckpt_path
    # Priority 2: the checkpoint that sorts last among the non-`last` checkpoints
    model_paths = []
    for p in sorted(dirpath.glob("*.ckpt")):
if "last" in p.name:
continue
model_paths.append(p)
if len(model_paths) > 0:
return model_paths[-1]
else:
Log.info("No checkpoint found, set model_path to None")
return None
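
# Example (hypothetical directory layout):
#   checkpoints/{e10.ckpt, e20.ckpt, last.ckpt} -> returns checkpoints/last.ckpt
#   checkpoints/{e10.ckpt, e20.ckpt}            -> returns checkpoints/e20.ckpt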

def get_resume_ckpt_path(resume_mode, ckpt_dir=None):
    """`resume_mode` is either an explicit checkpoint path or "last"."""
    if Path(resume_mode).exists():  # this is a path
return resume_mode
assert resume_mode == "last"
return find_last_ckpt_path(ckpt_dir)

def select_state_dict_by_prefix(state_dict, prefix, new_prefix=""):
    """
    For each weight whose key starts with {prefix}, strip {prefix} and collect
    the result into a new state_dict.
    Args:
        state_dict: dict
        prefix: str
        new_prefix: str, if given, the new key is {new_prefix} + {old_key[len(prefix):]}
    Returns:
        state_dict_new: dict
    """
state_dict_new = {}
for k in list(state_dict.keys()):
if k.startswith(prefix):
new_key = new_prefix + k[len(prefix) :]
state_dict_new[new_key] = state_dict[k]
return state_dict_new
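
# Usage sketch, e.g. stripping the "model." prefix from a pytorch-lightning
# checkpoint (hypothetical path):
#
#   ckpt = torch.load("last.ckpt", map_location="cpu")
#   sd = select_state_dict_by_prefix(ckpt["state_dict"], prefix="model.")
#   # "model.encoder.weight" -> "encoder.weight"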

def detach_to_cpu(in_dict):
    """Recursively detach all tensors in `in_dict` and move them to CPU."""
return recursive_detach(in_dict, to_cpu=True)

def to_cuda(data):
    """Move tensors in a (possibly nested) batch to CUDA; leave non-tensor data as-is."""
if isinstance(data, torch.Tensor):
return data.cuda()
elif isinstance(data, dict):
return {k: to_cuda(v) for k, v in data.items()}
elif isinstance(data, list):
return [to_cuda(v) for v in data]
else:
return data
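
# Usage sketch: nested containers are handled, non-tensors pass through:
#
#   batch = {"imgs": torch.zeros(2, 3), "meta": {"names": ["a", "b"]}}
#   batch = to_cuda(batch)  # tensors moved to GPU, the list of names untouched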

def get_valid_mask(max_len, valid_len, device="cpu"):
    """Boolean mask of shape (max_len,) with the first `valid_len` entries True."""
    mask = torch.zeros(max_len, dtype=torch.bool, device=device)
mask[:valid_len] = True
return mask

def length_to_mask(lengths, max_len):
    """
    Args:
        lengths: (B,) integer tensor of valid lengths
    Returns: (B, max_len) bool mask, True where the index is < length
    """
mask = torch.arange(max_len, device=lengths.device).expand(
len(lengths), max_len
) < lengths.unsqueeze(1)
return mask
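
# Example:
#   >>> length_to_mask(torch.tensor([2, 4]), max_len=4)
#   tensor([[ True,  True, False, False],
#           [ True,  True,  True,  True]])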

def repeat_to_max_len(x, max_len, dim=0):
    """Repeat the last frame of x along `dim` until its length reaches max_len."""
assert isinstance(x, torch.Tensor)
if x.shape[dim] == max_len:
return x
elif x.shape[dim] < max_len:
x = x.clone()
x = x.transpose(0, dim)
x = torch.cat([x, repeat(x[-1:], "b ... -> (b r) ...", r=max_len - x.shape[0])])
x = x.transpose(0, dim)
return x
else:
        raise ValueError(f"Unexpected length vs. max_len: {x.shape[dim]} vs. {max_len}")
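
# Example: pad a (3, 1) sequence to length 5 by repeating the last frame:
#   >>> x = torch.arange(3.0)[:, None]
#   >>> repeat_to_max_len(x, max_len=5, dim=0).squeeze(-1)
#   tensor([0., 1., 2., 2., 2.])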

def repeat_to_max_len_dict(x_dict, max_len, dim=0):
    """Apply `repeat_to_max_len` to every value in x_dict (modified in place)."""
for k, v in x_dict.items():
x_dict[k] = repeat_to_max_len(v, max_len, dim=dim)
return x_dict

class Transpose(nn.Module):
    """Transpose two dimensions as an nn.Module."""

    def __init__(self, dim1, dim2):
        super().__init__()
self.dim1 = dim1
self.dim2 = dim2

    def forward(self, x):
return x.transpose(self.dim1, self.dim2)
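
# Usage sketch: lets tensor transposes live inside nn.Sequential, e.g. around
# a Conv1d that expects (B, C, L) while the pipeline carries (B, L, C):
#
#   net = nn.Sequential(Transpose(1, 2), nn.Conv1d(16, 32, 3), Transpose(1, 2))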

class GaussianSmooth(nn.Module):
    """Gaussian low-pass filter along `dim`, with replicate padding at the ends."""

    def __init__(self, sigma=3, dim=-1):
        super().__init__()
        # scipy's private 1D-kernel helper; the radius matches scipy's default truncate=4.0
        kernel_smooth = _gaussian_kernel1d(
            sigma=sigma, order=0, radius=int(4 * sigma + 0.5)
        )
kernel_smooth = torch.from_numpy(kernel_smooth).float()[None, None] # (1, 1, K)
self.register_buffer("kernel_smooth", kernel_smooth, persistent=False)
self.dim = dim

    def forward(self, x):
        """x: (..., f, ...) with the frame axis f at `self.dim`; output has the same shape."""
rad = self.kernel_smooth.size(-1) // 2
x = x.transpose(self.dim, -1)
x_shape = x.shape[:-1]
x = rearrange(x, "... f -> (...) 1 f") # (NB, 1, f)
x = F.pad(x[None], (rad, rad, 0, 0), mode="replicate")[0]
x = F.conv1d(x, self.kernel_smooth)
x = x.squeeze(1).reshape(*x_shape, -1) # (..., f)
x = x.transpose(-1, self.dim)
return x
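
# Usage sketch: smooth a (B, F, C) motion sequence along the frame axis:
#
#   smooth = GaussianSmooth(sigma=3, dim=1)
#   y = smooth(torch.randn(2, 100, 3))  # same shape, low-pass filtered over F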

def gaussian_smooth(x, sigma=3, dim=-1):
    """Functional version of GaussianSmooth: Gaussian low-pass filter x along `dim`."""
kernel_smooth = _gaussian_kernel1d(
sigma=sigma, order=0, radius=int(4 * sigma + 0.5)
)
kernel_smooth = (
torch.from_numpy(kernel_smooth).float()[None, None].to(x)
) # (1, 1, K)
rad = kernel_smooth.size(-1) // 2
x = x.transpose(dim, -1)
x_shape = x.shape[:-1]
x = rearrange(x, "... f -> (...) 1 f") # (NB, 1, f)
x = F.pad(x[None], (rad, rad, 0, 0), mode="replicate")[0]
x = F.conv1d(x, kernel_smooth)
x = x.squeeze(1).reshape(*x_shape, -1) # (..., f)
x = x.transpose(-1, dim)
return x

def moving_average_smooth(x, window_size=5, dim=-1):
    """Box-filter x along `dim`; window_size should be odd to preserve the length."""
    kernel_smooth = torch.ones(window_size).float() / window_size
    kernel_smooth = kernel_smooth[None, None].to(x)  # (1, 1, window_size)
rad = kernel_smooth.size(-1) // 2
x = x.transpose(dim, -1)
x_shape = x.shape[:-1]
x = rearrange(x, "... f -> (...) 1 f") # (NB, 1, f)
x = F.pad(x[None], (rad, rad, 0, 0), mode="replicate")[0]
x = F.conv1d(x, kernel_smooth)
x = x.squeeze(1).reshape(*x_shape, -1) # (..., f)
x = x.transpose(-1, dim)
return x
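
# Usage sketch comparing the two functional smoothers on a noisy 1D signal:
#
#   t = torch.linspace(0, 6.28, 100)
#   noisy = torch.sin(t) + 0.1 * torch.randn(100)
#   a = gaussian_smooth(noisy, sigma=3)               # Gaussian low-pass
#   b = moving_average_smooth(noisy, window_size=5)   # box filter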