File size: 5,745 Bytes
1cd928a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
# F (torch.nn.functional) is used by repeat_expand_2d_other; repeat_expand calls torch.nn.functional directly
import torch.nn.functional as F
import numpy as np
import torch
import yaml
import os
from typing import Optional, Union
import glob
import re
try:
from typing import Literal
except Exception:
from typing_extensions import Literal
def wav_pad(wav, multiple=200):
    """Resize each row of ``wav`` to the next multiple of ``multiple``.

    The new length is ``seq_len`` rounded up to a multiple of ``multiple``.
    The resizing is done by ``repeat_expand``, which defaults to
    nearest-neighbour interpolation. No zeros are appended.

    Args:
        wav: 2-D input whose last axis is the sequence axis. Unpacking its
            shape into two names rejects any other rank.
        multiple: length granularity. Defaults to 200.

    Returns:
        The ``repeat_expand`` result with last axis of the rounded-up length.
    """
    _, seq_len = wav.shape
    # Integer ceiling to the next multiple.
    target_len = ((seq_len + (multiple - 1)) // multiple) * multiple
    return repeat_expand(wav, target_len)
def repeat_expand_2d(content, target_len, mode='left'):
    """Resize the last axis of a 2-D ``[h, t]`` input to ``target_len``.

    Args:
        content: 2-D input ``[h, t]``.
        target_len: output length of the last axis.
        mode: ``'left'`` selects ``repeat_expand_2d_left``. Any other value is
            forwarded as the interpolation mode to ``repeat_expand_2d_other``.

    Returns:
        The selected helper's ``[h, target_len]`` result.
    """
    if mode != 'left':
        return repeat_expand_2d_other(content, target_len, mode)
    return repeat_expand_2d_left(content, target_len)
def repeat_expand_3d(content, target_len, mode='left'):
    """Apply ``repeat_expand_2d`` independently to every item of a batch.

    Args:
        content: batched input ``[B, h, t]``. Each ``content[i]`` is resized
            on its own.
        target_len: output length of the last axis.
        mode: forwarded unchanged to ``repeat_expand_2d``.

    Returns:
        Tensor ``[B, h, target_len]`` stacking the per-item results along dim 0.
    """
    expanded = [
        repeat_expand_2d(content[idx], target_len, mode)
        for idx in range(content.shape[0])
    ]
    return torch.stack(expanded, dim=0)
def repeat_expand_2d_left(content, target_len):
    """Resize the last axis of ``content`` by left-aligned frame repetition.

    Source frame ``k`` covers the target interval
    ``[k * target_len / src_len, (k + 1) * target_len / src_len)``. Each target
    index ``i`` copies the source frame whose interval contains ``i``.

    When upsampling (``target_len >= src_len``), each source frame is repeated
    about ``target_len / src_len`` times. When downsampling, frames are picked
    by the same interval rule. The previous loop could only advance one
    source frame per output step, so downsampling just truncated.

    Args:
        content: 2-D tensor ``[h, t]``, where ``t`` is the source length.
        target_len: length of the output's last axis.

    Returns:
        ``torch.float`` tensor ``[h, target_len]`` on ``content``'s device.

    Raises:
        ValueError: if ``content`` has an empty last axis but ``target_len > 0``.
    """
    src_len = content.shape[-1]
    if src_len == 0 and target_len > 0:
        raise ValueError("repeat_expand_2d_left: cannot expand an empty sequence")
    # Right edge of each source frame's interval on the target axis. These are
    # the same values the original per-frame comparison used.
    edges = torch.arange(1, src_len + 1) * target_len / src_len
    positions = torch.arange(target_len, dtype=edges.dtype)
    # For each target index, the number of right edges <= it is the index of
    # the source frame covering it. Clamp guards against float rounding at the end.
    index = torch.searchsorted(edges, positions, right=True).clamp_(max=src_len - 1)
    return content[:, index.to(content.device)].float()
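# mode: any F.interpolate mode valid for 3-D (1, h, t) input: 'nearest', 'linear', 'area' or 'nearest-exact' ('bilinear'/'bicubic' need 4-D and 'trilinear' 5-D input, so they fail here)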
def repeat_expand_2d_other(content, target_len, mode='nearest'):
    """Resize the last axis of a 2-D ``[h, t]`` tensor with ``F.interpolate``.

    Args:
        content: 2-D tensor ``[h, t]``; ``h`` is passed as the channel axis.
        target_len: output length of the last axis.
        mode: interpolation mode forwarded to ``F.interpolate``.

    Returns:
        Tensor ``[h, target_len]``.
    """
    # Add a leading batch axis so interpolate receives (1, h, t).
    batched = content[None, :, :]
    resized = F.interpolate(batched, size=target_len, mode=mode)
    return resized.squeeze(0)
def repeat_expand(
    content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
):
    """Resize ``content`` along its last axis to ``target_len``.

    This is a wrapper around ``torch.nn.functional.interpolate`` that also
    accepts 1-D/2-D inputs and NumPy arrays. Missing leading axes are added
    before interpolating and removed afterwards.

    Args:
        content: 1-D, 2-D or 3-D ``torch.Tensor`` or ``np.ndarray``.
        target_len: desired length of the last axis.
        mode: interpolation mode passed to ``interpolate``. Defaults to "nearest".

    Returns:
        Same container type (tensor or ndarray) and rank as ``content``, with
        the last axis resized to ``target_len``.

    Raises:
        ValueError: if ``content`` is not 1-D, 2-D or 3-D.
    """
    ndim = content.ndim
    if ndim == 1:
        content = content[None, None]
    elif ndim == 2:
        content = content[None]
    elif ndim != 3:
        # Explicit raise instead of `assert` so the check survives `python -O`.
        raise ValueError(
            f"repeat_expand expects a 1-D, 2-D or 3-D input, got {ndim}-D"
        )

    is_np = isinstance(content, np.ndarray)
    if is_np:
        content = torch.from_numpy(content)
    results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)
    if is_np:
        results = results.numpy()

    # Drop the leading axes added above so the output rank matches the input.
    # 3-D input previously fell through every branch and returned None.
    if ndim == 1:
        return results[0, 0]
    if ndim == 2:
        return results[0]
    return results
class DotDict(dict):
    """A ``dict`` whose keys can also be read, set and deleted as attributes.

    Behaviour to be aware of:

    - Reading a missing attribute returns ``None`` instead of raising, so
      ``hasattr`` is true for any name.
    - A value that is exactly a plain ``dict`` is returned wrapped in a new
      ``DotDict`` (a shallow copy). Writing through ``cfg.section.key = v``
      therefore does not change the stored nested dict.
    - Deleting a missing attribute raises ``KeyError`` (from ``dict.__delitem__``).
    """

    def __getattr__(self, key, *default):
        found = dict.get(self, key, *default)
        if type(found) is dict:
            return DotDict(found)
        return found

    # Attribute writes and deletes go straight to the underlying mapping.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
def load_config(config_path):
    """Load a YAML file into a ``DotDict``.

    Args:
        config_path: path to the YAML configuration file.

    Returns:
        ``DotDict`` built from the parsed YAML document.

    Raises:
        ValueError: if the file cannot be opened or parsed, or the parsed
            document cannot be converted into a dict (e.g. an empty file
            parses to None). The underlying error is chained as ``__cause__``.
    """
    try:
        with open(config_path, "r", encoding="utf-8") as config:
            args = yaml.safe_load(config)
        return DotDict(args)
    except Exception as err:
        # Previously a bare `except:` turned even KeyboardInterrupt into an
        # anonymous ValueError and discarded the original error.
        raise ValueError(f"failed to load config from {config_path!r}") from err
############ from controlspeech ################
def get_last_checkpoint(work_dir, steps=None):
    """Load the first checkpoint listed by ``get_all_ckpts``.

    ``get_all_ckpts`` orders paths by descending step number, so this is the
    highest-step checkpoint.

    Args:
        work_dir: directory forwarded to ``get_all_ckpts``.
        steps: optional step number forwarded to ``get_all_ckpts``.

    Returns:
        ``(checkpoint, path)`` with the checkpoint loaded onto the CPU, or
        ``(None, None)`` when no checkpoint path is found.
    """
    candidates = get_all_ckpts(work_dir, steps)
    if not candidates:
        return None, None
    newest_path = candidates[0]
    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    return torch.load(newest_path, map_location='cpu'), newest_path
def get_all_ckpts(work_dir, steps=None):
    """List ``model_ckpt_steps_<N>.ckpt`` files in ``work_dir``, highest step first.

    Args:
        work_dir: directory to search. It is glob-escaped, so directory names
            containing ``[``, ``*`` or ``?`` are matched literally.
        steps: if given, only the file for exactly that step is listed.

    Returns:
        Paths sorted by descending step number. Files that match the glob but
        have no numeric ``steps_<N>.ckpt`` part (e.g.
        ``model_ckpt_steps_best.ckpt``) are skipped; they used to raise IndexError.
    """
    step_part = '*' if steps is None else f'{steps}'
    pattern = f'{glob.escape(work_dir)}/model_ckpt_steps_{step_part}.ckpt'

    numbered = []
    for path in glob.glob(pattern):
        # Raw string: the old pattern had invalid escapes (\_, \d, \.), which warn on 3.12+.
        found = re.findall(r'.*steps_(\d+)\.ckpt', path)
        if found:
            numbered.append((int(found[0]), path))
    # Stable descending sort, same order as the former key=-step.
    numbered.sort(key=lambda item: item[0], reverse=True)
    return [path for _, path in numbered]
def _select_model_state_dict(state_dict, model_name):
    """Return the part of a checkpoint ``state_dict`` belonging to ``model_name``.

    Two layouts are handled:

    - Flat (some key contains a dot): keys starting with ``'<model_name>.'``
      are kept with that prefix stripped.
    - Nested (no key contains a dot): ``state_dict[model_name]`` is returned
      as-is. For a dotted name such as ``'a.b'``, ``state_dict['a']`` is used
      and its ``'b.'``-prefixed keys are kept with the prefix stripped.
    """
    if any('.' in key for key in state_dict):
        prefix = f'{model_name}.'
        return {key[len(prefix):]: value
                for key, value in state_dict.items() if key.startswith(prefix)}
    if '.' not in model_name:
        return state_dict[model_name]
    base_name, rest_name = model_name.split('.', 1)
    prefix = f'{rest_name}.'
    return {key[len(prefix):]: value
            for key, value in state_dict[base_name].items() if key.startswith(prefix)}


def _drop_shape_mismatches(cur_model, state_dict):
    """Delete, in place, entries whose shape differs from ``cur_model``'s.

    Each dropped key is printed. Keys the model does not have are left for
    ``load_state_dict(strict=False)`` to ignore.
    """
    model_state = cur_model.state_dict()
    unmatched_keys = []
    for key, param in state_dict.items():
        if key in model_state and model_state[key].shape != param.shape:
            unmatched_keys.append(key)
            print("| Unmatched keys: ", key, model_state[key].shape, param.shape)
    # Delete in place (not by rebuilding the dict) so any attributes of the
    # loaded mapping object are preserved.
    for key in unmatched_keys:
        del state_dict[key]


def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True):
    """Load ``model_name``'s weights from a checkpoint into ``cur_model``.

    Args:
        cur_model: module whose ``load_state_dict`` receives the weights.
        ckpt_base_dir: a checkpoint file, or a directory in which
            ``get_last_checkpoint`` picks the highest-step checkpoint.
        model_name: selects the weights inside ``checkpoint["state_dict"]``
            (flat ``'<model_name>.'`` prefix or nested sub-dict).
        force: if True, a missing checkpoint raises; otherwise it is only printed.
        strict: passed to ``load_state_dict``. When False, entries whose shape
            differs from the model's are dropped (and printed) first.

    Raises:
        AssertionError: if no checkpoint is found and ``force`` is True.
    """
    if os.path.isfile(ckpt_base_dir):
        base_dir = os.path.dirname(ckpt_base_dir)
        ckpt_path = ckpt_base_dir
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(ckpt_base_dir, map_location='cpu')
    else:
        base_dir = ckpt_base_dir
        checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir)

    if checkpoint is None:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            # Explicit raise (was `assert False`) so the check survives `python -O`.
            raise AssertionError(e_msg)
        print(e_msg)
        return

    state_dict = _select_model_state_dict(checkpoint["state_dict"], model_name)
    if not strict:
        _drop_shape_mismatches(cur_model, state_dict)
    cur_model.load_state_dict(state_dict, strict=strict)
    print(f"| load '{model_name}' from '{ckpt_path}'.")
|