from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from pytorch_lightning.utilities.memory import recursive_detach

# NOTE: _gaussian_kernel1d is a private SciPy helper; it may move between versions.
from scipy.ndimage._filters import _gaussian_kernel1d

from genmo.utils.pylogger import Log
|
|
def load_pretrained_model(model, ckpt_path):
    """
    Load a checkpoint into model, deferring to the model's own
    load_pretrained_model() hook when it defines one.
    """
    assert Path(ckpt_path).exists()

    if hasattr(model, "load_pretrained_model"):
        ckpt = model.load_pretrained_model(ckpt_path)
    else:
        Log.info(f"Loading ckpt: {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt, strict=True)
    return ckpt
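# Usage sketch (the module class and checkpoint path below are hypothetical):
#   model = MyModule()
#   ckpt = load_pretrained_model(model, "outputs/ckpts/e10.ckpt")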
|
|
def find_last_ckpt_path(dirpath):
    """
    Assume checkpoints are named e{epoch}* or last*, following the
    pytorch-lightning convention.
    """
    assert dirpath is not None
    dirpath = Path(dirpath)
    assert dirpath.exists()

    auto_last_ckpt_path = dirpath / "last.ckpt"
    if auto_last_ckpt_path.exists():
        return auto_last_ckpt_path

    # Fall back to the lexicographically last non-"last" checkpoint.
    model_paths = []
    for p in sorted(dirpath.glob("*.ckpt")):
        if "last" in p.name:
            continue
        model_paths.append(p)
    if len(model_paths) > 0:
        return model_paths[-1]
    else:
        Log.info("No checkpoint found, set model_path to None")
        return None
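# Note that the glob sort is lexicographic, so epoch-named checkpoints are assumed
# to be zero-padded. Sketch with a hypothetical layout:
#   # ckpts/ contains e08.ckpt, e09.ckpt, e10.ckpt
#   find_last_ckpt_path("ckpts")  # -> Path("ckpts/e10.ckpt")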
|
|
def get_resume_ckpt_path(resume_mode, ckpt_dir=None):
    """resume_mode is either an explicit checkpoint path or the string "last"."""
    if Path(resume_mode).exists():
        return resume_mode
    assert resume_mode == "last"
    return find_last_ckpt_path(ckpt_dir)
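# Usage sketch (paths are hypothetical):
#   get_resume_ckpt_path("last", ckpt_dir="outputs/ckpts")  # newest ckpt in dir
#   get_resume_ckpt_path("outputs/ckpts/e10.ckpt")          # explicit path wins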
|
|
def select_state_dict_by_prefix(state_dict, prefix, new_prefix=""):
    """
    For each key that starts with {prefix}, strip {prefix} and collect the
    entry into a new state_dict.
    Args:
        state_dict: dict
        prefix: str
        new_prefix: str, if given, the new key will be {new_prefix} + {old_key[len(prefix):]}
    Returns:
        state_dict_new: dict
    """
    state_dict_new = {}
    for k in list(state_dict.keys()):
        if k.startswith(prefix):
            new_key = new_prefix + k[len(prefix) :]
            state_dict_new[new_key] = state_dict[k]
    return state_dict_new
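# Example: strip a Lightning-style "model." prefix (keys here are made up):
#   sd = {"model.backbone.weight": w, "loss.scale": s}
#   select_state_dict_by_prefix(sd, "model.")                     # {"backbone.weight": w}
#   select_state_dict_by_prefix(sd, "model.", new_prefix="net.")  # {"net.backbone.weight": w}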
|
|
def detach_to_cpu(in_dict):
    return recursive_detach(in_dict, to_cpu=True)
|
|
def to_cuda(data):
    """Recursively move tensors in a batch to cuda(), leaving non-tensor data as-is."""
    if isinstance(data, torch.Tensor):
        return data.cuda()
    elif isinstance(data, dict):
        return {k: to_cuda(v) for k, v in data.items()}
    elif isinstance(data, list):
        return [to_cuda(v) for v in data]
    else:
        return data
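# Example (keys and shapes are illustrative):
#   batch = {"img": torch.rand(2, 3), "meta": {"ids": [0, 1], "name": "seq0"}}
#   batch = to_cuda(batch)  # "img" moves to GPU; "ids" and "name" pass through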
|
|
def get_valid_mask(max_len, valid_len, device="cpu"):
    """Boolean mask of shape (max_len,) with the first valid_len entries True."""
    mask = torch.zeros(max_len, dtype=torch.bool, device=device)
    mask[:valid_len] = True
    return mask
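# Example: get_valid_mask(5, 3) -> tensor([True, True, True, False, False])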
|
|
def length_to_mask(lengths, max_len):
    """
    Args:
        lengths: (B,) valid length of each sequence
    Returns: (B, max_len) boolean mask
    """
    mask = torch.arange(max_len, device=lengths.device).expand(
        len(lengths), max_len
    ) < lengths.unsqueeze(1)
    return mask
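# Example: length_to_mask(torch.tensor([1, 3]), max_len=4)
#   -> [[True, False, False, False],
#       [True, True,  True,  False]]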
|
|
def repeat_to_max_len(x, max_len, dim=0):
    """Repeat the last frame to max_len along dim"""
    assert isinstance(x, torch.Tensor)
    if x.shape[dim] == max_len:
        return x
    elif x.shape[dim] < max_len:
        x = x.clone()
        x = x.transpose(0, dim)
        x = torch.cat([x, repeat(x[-1:], "b ... -> (b r) ...", r=max_len - x.shape[0])])
        x = x.transpose(0, dim)
        return x
    else:
        raise ValueError(f"Unexpected length vs. max_len: {x.shape[dim]} vs. {max_len}")
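# Example: pad a 3-frame sequence to 5 frames by repeating the last frame
#   x = torch.arange(3.0)[:, None]      # (3, 1)
#   repeat_to_max_len(x, 5, dim=0)      # (5, 1): frames [0, 1, 2, 2, 2]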
|
|
def repeat_to_max_len_dict(x_dict, max_len, dim=0):
    """Apply repeat_to_max_len to every tensor in x_dict (modified in place)."""
    for k, v in x_dict.items():
        x_dict[k] = repeat_to_max_len(v, max_len, dim=dim)
    return x_dict
|
|
class Transpose(nn.Module):
    def __init__(self, dim1, dim2):
        super().__init__()
        self.dim1 = dim1
        self.dim2 = dim2

    def forward(self, x):
        return x.transpose(self.dim1, self.dim2)
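# Transpose is handy inside nn.Sequential, e.g. to feed (B, T, C) features to
# Conv1d (layer sizes here are illustrative):
#   net = nn.Sequential(Transpose(1, 2), nn.Conv1d(64, 64, 3, padding=1), Transpose(1, 2))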
|
|
class GaussianSmooth(nn.Module):
    def __init__(self, sigma=3, dim=-1):
        super().__init__()
        kernel_smooth = _gaussian_kernel1d(
            sigma=sigma, order=0, radius=int(4 * sigma + 0.5)
        )
        kernel_smooth = torch.from_numpy(kernel_smooth).float()[None, None]
        self.register_buffer("kernel_smooth", kernel_smooth, persistent=False)
        self.dim = dim

    def forward(self, x):
        """x: (..., f, ...) with the frame axis at dim"""
        rad = self.kernel_smooth.size(-1) // 2

        x = x.transpose(self.dim, -1)
        x_shape = x.shape[:-1]
        x = rearrange(x, "... f -> (...) 1 f")
        # Replicate-pad so the smoothed output keeps the input length.
        x = F.pad(x[None], (rad, rad, 0, 0), mode="replicate")[0]
        x = F.conv1d(x, self.kernel_smooth)
        x = x.squeeze(1).reshape(*x_shape, -1)
        x = x.transpose(-1, self.dim)
        return x
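# Usage sketch (shapes are illustrative): smooth a (B, T, C) sequence over time
#   smooth = GaussianSmooth(sigma=3, dim=1)
#   y = smooth(torch.randn(2, 100, 7))  # y.shape == (2, 100, 7)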
|
|
def gaussian_smooth(x, sigma=3, dim=-1):
    """Functional counterpart of GaussianSmooth."""
    kernel_smooth = _gaussian_kernel1d(
        sigma=sigma, order=0, radius=int(4 * sigma + 0.5)
    )
    kernel_smooth = torch.from_numpy(kernel_smooth).float()[None, None].to(x)
    rad = kernel_smooth.size(-1) // 2

    x = x.transpose(dim, -1)
    x_shape = x.shape[:-1]
    x = rearrange(x, "... f -> (...) 1 f")
    x = F.pad(x[None], (rad, rad, 0, 0), mode="replicate")[0]
    x = F.conv1d(x, kernel_smooth)
    x = x.squeeze(1).reshape(*x_shape, -1)
    x = x.transpose(-1, dim)
    return x
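# Example: gaussian_smooth(torch.randn(100, 3), sigma=2, dim=0) keeps shape (100, 3)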
|
|
def moving_average_smooth(x, window_size=5, dim=-1):
    """Box-filter smoothing along dim. window_size should be odd so that the
    output length matches the input length."""
    kernel_smooth = torch.ones(window_size).float() / window_size
    kernel_smooth = kernel_smooth[None, None].to(x)
    rad = kernel_smooth.size(-1) // 2

    x = x.transpose(dim, -1)
    x_shape = x.shape[:-1]
    x = rearrange(x, "... f -> (...) 1 f")
    x = F.pad(x[None], (rad, rad, 0, 0), mode="replicate")[0]
    x = F.conv1d(x, kernel_smooth)
    x = x.squeeze(1).reshape(*x_shape, -1)
    x = x.transpose(-1, dim)
    return x
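# Example (odd window): moving_average_smooth(torch.tensor([0., 0., 1., 0., 0.]), window_size=3)
#   -> tensor([0.0000, 0.3333, 0.3333, 0.3333, 0.0000])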