# HQ-SVC: utils/utils.py
# (Header copied from the Hugging Face file view: uploaded by shawnpi in
#  "Upload 753 files", commit 1cd928a, verified.)
# Only for repeat_expand
import torch.nn.functional as F
import numpy as np
import torch
import yaml
import os
from typing import Optional, Union
import glob
import re
try:
from typing import Literal
except Exception:
from typing_extensions import Literal
def wav_pad(wav, multiple=200):
    """Resize a 2-D waveform batch so its last axis is a multiple of ``multiple``.

    Note: despite the name, no zeros are appended. The signal is resampled to
    the rounded-up length with ``repeat_expand`` (default "nearest" mode).

    Args:
        wav: 2-D array-like; the unpacking below rejects any other rank.
            Presumably (batch, samples) -- confirm against callers.
        multiple: the output length is rounded up to a multiple of this.

    Returns:
        The resized waveform, same rank as ``wav``.
    """
    _, seq_len = wav.shape  # raises ValueError for non-2-D input
    n_chunks = (seq_len + (multiple - 1)) // multiple
    return repeat_expand(wav, n_chunks * multiple)
def repeat_expand_2d(content, target_len, mode='left'):
    """Stretch a 2-D ``[h, t]`` feature map to ``target_len`` frames.

    ``mode == 'left'`` uses the frame-repeating expansion
    ``repeat_expand_2d_left``. Any other mode is forwarded to
    ``repeat_expand_2d_other`` as an interpolation mode.
    """
    if mode != 'left':
        return repeat_expand_2d_other(content, target_len, mode)
    return repeat_expand_2d_left(content, target_len)
def repeat_expand_3d(content, target_len, mode='left'):
    """Apply ``repeat_expand_2d`` to every item of a ``[B, h, t]`` batch.

    Returns:
        A tensor of the per-item results stacked along a new leading axis.
    """
    expanded = [
        repeat_expand_2d(content[b], target_len, mode)
        for b in range(content.shape[0])
    ]
    return torch.stack(expanded, dim=0)
def repeat_expand_2d_left(content, target_len):
    """Resample a 2-D ``[h, t]`` tensor to ``target_len`` frames by repetition.

    Output frame ``i`` copies source frame ``p``, where ``p`` is the last frame
    whose boundary ``p * target_len / src_len`` is ``<= i``. This is a
    "left"/floor nearest-neighbour mapping. For upsampling it matches the
    previous loop-based implementation. For downsampling it now samples
    across the whole input instead of returning a truncated prefix.

    Args:
        content: tensor of shape ``[h, t]``.
        target_len: number of output frames.

    Returns:
        A float32 tensor of shape ``[h, target_len]`` on ``content.device``.

    Raises:
        ValueError: if ``content`` has no frames but ``target_len > 0``.
    """
    src_len = content.shape[-1]
    if target_len > 0 and src_len == 0:
        raise ValueError("cannot expand an empty sequence to a non-zero length")
    # bounds[p] is where source frame p starts on the target time axis.
    bounds = torch.arange(src_len + 1) * target_len / max(src_len, 1)
    positions = torch.arange(target_len, dtype=bounds.dtype)
    # Index of the last boundary <= i (right=True puts equality on the left).
    index = torch.searchsorted(bounds, positions, right=True) - 1
    index = index.clamp_(0, max(src_len - 1, 0)).to(content.device)
    # The original always produced float32 output; keep that contract.
    return content[:, index].to(dtype=torch.float)
def repeat_expand_2d_other(content, target_len, mode = 'nearest'):
    """Resize a 2-D ``[h, t]`` tensor to ``[h, target_len]`` with ``F.interpolate``.

    The tensor is handed to ``interpolate`` as 3-D ``[1, h, t]``. Only the 1-D
    modes apply ('nearest', 'linear', 'area', 'nearest-exact'); 'bilinear',
    'bicubic' and 'trilinear' need 4-D/5-D input.
    """
    batched = content.unsqueeze(0)
    resized = F.interpolate(batched, size=target_len, mode=mode)
    return resized.squeeze(0)
def repeat_expand(
    content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
) -> Union[torch.Tensor, np.ndarray]:
    """Resize ``content`` along its last axis to ``target_len`` samples.

    Wrapper around ``torch.nn.functional.interpolate`` that also accepts 1-D
    and 2-D inputs and NumPy arrays.

    Args:
        content: 1-D, 2-D or 3-D tensor or NumPy array; the last axis is resized.
        target_len: output length of the last axis.
        mode: interpolation mode passed to ``interpolate``. Defaults to "nearest".

    Returns:
        Same container type (tensor / ndarray) and rank as ``content``.

    Raises:
        ValueError: if ``content`` is not 1-D, 2-D or 3-D.
    """
    ndim = content.ndim
    if ndim not in (1, 2, 3):
        # Explicit raise instead of `assert` so the check survives `python -O`.
        raise ValueError(f"repeat_expand expects a 1-D, 2-D or 3-D input, got {ndim}-D")
    is_np = isinstance(content, np.ndarray)
    tensor = torch.from_numpy(content) if is_np else content
    # interpolate works on (N, C, L); add the missing leading axes.
    if ndim == 1:
        tensor = tensor[None, None]
    elif ndim == 2:
        tensor = tensor[None]
    results = torch.nn.functional.interpolate(tensor, size=target_len, mode=mode)
    if is_np:
        results = results.numpy()
    # Strip the axes added above so the output rank matches the input rank.
    # (Previously a 3-D input fell through and returned None.)
    if ndim == 1:
        return results[0, 0]
    if ndim == 2:
        return results[0]
    return results
class DotDict(dict):
    """``dict`` whose keys can also be read, set and deleted as attributes.

    Reading a missing attribute returns ``None`` instead of raising
    ``AttributeError``. A nested plain ``dict`` value is returned wrapped in a
    *new* ``DotDict`` (a shallow copy). So ``cfg.section.key = v`` writes to
    that copy, not into ``cfg``; use ``cfg.section["key"]``-style access on
    the stored dict (``cfg["section"]["key"] = v``) to mutate in place.
    """

    def __getattr__(*args):
        # Only reached when normal attribute lookup fails; args is (self, name).
        value = dict.get(*args)
        if type(value) is dict:
            return DotDict(value)
        return value

    # Attribute writes and deletes go straight to the underlying mapping.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
def load_config(config_path):
    """Read a YAML file and wrap the parsed mapping in a ``DotDict``.

    Args:
        config_path: path to the YAML file.

    Returns:
        A ``DotDict`` of the parsed YAML.

    Raises:
        ValueError: if the file cannot be opened or parsed, or the result
            cannot be turned into a ``DotDict`` (e.g. an empty file parses to
            ``None``). The underlying exception is chained as ``__cause__``.
    """
    try:
        # YAML is defined as UTF-8 (or UTF-16/32); don't rely on the locale.
        with open(config_path, "r", encoding="utf-8") as config:
            args = yaml.safe_load(config)
        return DotDict(args)
    except Exception as err:
        # Narrower than the old bare `except:` so Ctrl-C / SystemExit propagate,
        # and the real cause is kept instead of being discarded.
        raise ValueError(f"failed to load config from {config_path!r}") from err
############ from controlspeech ################
def get_last_checkpoint(work_dir, steps=None):
    """Load the highest-step checkpoint found by ``get_all_ckpts``.

    Args:
        work_dir: directory holding ``model_ckpt_steps_*.ckpt`` files.
        steps: if given, only the checkpoint for exactly that step is considered.

    Returns:
        ``(checkpoint, path)`` with the checkpoint loaded onto the CPU, or
        ``(None, None)`` when no matching file exists.
    """
    candidates = get_all_ckpts(work_dir, steps)
    if not candidates:
        return None, None
    newest = candidates[0]
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    return torch.load(newest, map_location='cpu'), newest
def get_all_ckpts(work_dir, steps=None):
    """List ``model_ckpt_steps_<N>.ckpt`` files in ``work_dir``, highest ``N`` first.

    Args:
        work_dir: directory to search (used verbatim inside a glob pattern).
        steps: if given, only the file for exactly that step is returned.

    Returns:
        A list of paths sorted by step number in descending order. Files that
        match the glob but do not end in ``steps_<digits>.ckpt`` (e.g.
        ``model_ckpt_steps_best.ckpt``) are skipped instead of crashing the sort.
    """
    if steps is None:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
    else:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'
    # Raw string: the old non-raw '\_', '\d', '\.' were invalid escape
    # sequences (SyntaxWarning on Python 3.12+).
    step_re = re.compile(r'steps_(\d+)\.ckpt$')
    ranked = []
    for path in glob.glob(ckpt_path_pattern):
        match = step_re.search(path)
        if match is not None:
            ranked.append((int(match.group(1)), path))
    ranked.sort(key=lambda item: item[0], reverse=True)
    return [path for _, path in ranked]
def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True):
    """Load the weights for ``model_name`` from a checkpoint into ``cur_model``.

    Args:
        cur_model: module whose ``load_state_dict`` receives the weights.
        ckpt_base_dir: a checkpoint file, or a directory searched with
            ``get_last_checkpoint`` (the highest ``model_ckpt_steps_*.ckpt`` wins).
        model_name: selects the sub-model inside ``checkpoint["state_dict"]``.
            Either a key prefix (flat ``"model.layer.weight"`` layout) or, for
            nested layouts, a top-level key optionally followed by a dotted prefix.
        force: if True, a missing checkpoint raises; otherwise it is printed.
        strict: forwarded to ``load_state_dict``. When False, parameters whose
            shape differs from ``cur_model``'s are also dropped before loading.

    Raises:
        AssertionError: if no checkpoint is found and ``force`` is True.
    """
    if os.path.isfile(ckpt_base_dir):
        base_dir = os.path.dirname(ckpt_base_dir)
        ckpt_path = ckpt_base_dir
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(ckpt_base_dir, map_location='cpu')
    else:
        base_dir = ckpt_base_dir
        checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir)

    if checkpoint is None:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            # Explicit raise (the old `assert False` vanished under `python -O`);
            # same exception type, so existing handlers still work.
            raise AssertionError(e_msg)
        print(e_msg)
        return

    state_dict = _select_model_state_dict(checkpoint["state_dict"], model_name)
    if not strict:
        state_dict = _drop_shape_mismatches(state_dict, cur_model.state_dict())
    cur_model.load_state_dict(state_dict, strict=strict)
    print(f"| load '{model_name}' from '{ckpt_path}'.")


def _select_model_state_dict(state_dict, model_name):
    """Return the part of ``state_dict`` belonging to ``model_name``, prefix stripped."""
    if any('.' in k for k in state_dict):
        # Flat layout: keys look like "<model_name>.<param>".
        prefix = f'{model_name}.'
        return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
    # Nested layout: top-level keys are dot-free sub-model names.
    if '.' not in model_name:
        return state_dict[model_name]
    base_model_name, rest_model_name = model_name.split('.', 1)
    prefix = f'{rest_model_name}.'
    return {
        k[len(prefix):]: v for k, v in state_dict[base_model_name].items()
        if k.startswith(prefix)}


def _drop_shape_mismatches(state_dict, model_state_dict):
    """Return ``state_dict`` without entries whose shape differs from the model's."""
    kept = {}
    for key, param in state_dict.items():
        if key in model_state_dict:
            new_param = model_state_dict[key]
            if new_param.shape != param.shape:
                print("| Unmatched keys: ", key, new_param.shape, param.shape)
                continue
        # Keys unknown to the model are kept; strict=False tolerates them.
        kept[key] = param
    return kept