|
|
|
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import torch |
|
|
import yaml |
|
|
import os |
|
|
from typing import Optional, Union |
|
|
import glob |
|
|
import re |
|
|
|
|
|
try: |
|
|
from typing import Literal |
|
|
except Exception: |
|
|
from typing_extensions import Literal |
|
|
|
|
|
def wav_pad(wav, multiple=200):
    """Resize a 2-D waveform array so its last axis is a multiple of ``multiple``.

    The target length is the current length rounded up to the next multiple.
    The resize is delegated to ``repeat_expand`` with its default mode, so
    the signal is stretched to the new length rather than zero-padded.

    Args:
        wav: 2-D input; the original code names its axes (batch, seq_len).
        multiple: Granularity the output length is rounded up to.

    Returns:
        Whatever ``repeat_expand`` returns for the rounded-up length.
    """
    # Tuple unpacking doubles as a check that the input is exactly 2-D.
    _, n_samples = wav.shape
    n_chunks = (n_samples + (multiple - 1)) // multiple
    return repeat_expand(wav, n_chunks * multiple)
|
|
|
|
|
def repeat_expand_2d(content, target_len, mode='left'):
    """Expand a 2-D tensor along its last axis to ``target_len`` columns.

    Args:
        content: 2-D tensor to expand.
        target_len: Desired size of the last axis.
        mode: ``'left'`` selects ``repeat_expand_2d_left``; any other value
            is forwarded to ``repeat_expand_2d_other`` as the interpolation
            mode.

    Returns:
        The expanded tensor produced by the selected helper.
    """
    if mode == 'left':
        return repeat_expand_2d_left(content, target_len)
    return repeat_expand_2d_other(content, target_len, mode)
|
|
|
|
|
def repeat_expand_3d(content, target_len, mode='left'):
    """Apply ``repeat_expand_2d`` to every slice along axis 0 and re-stack.

    Args:
        content: Tensor indexed along axis 0; each ``content[i]`` is handed
            to ``repeat_expand_2d``.
        target_len: Desired size of the last axis of every slice.
        mode: Expansion mode forwarded unchanged to ``repeat_expand_2d``.

    Returns:
        The expanded slices stacked along a new leading axis.
    """
    expanded_slices = [
        repeat_expand_2d(content[idx], target_len, mode)
        for idx in range(content.shape[0])
    ]
    return torch.stack(expanded_slices, dim=0)
|
|
|
|
|
def repeat_expand_2d_left(content, target_len):
    """Expand a 2-D tensor to ``target_len`` columns by repeating source columns.

    Source column ``p`` owns the output positions ``i`` with
    ``p * target_len / src_len <= i < (p + 1) * target_len / src_len``.
    The source cursor only ever advances by one column per output position.

    Args:
        content: 2-D tensor; columns of the last axis are repeated.
        target_len: Number of columns in the result.

    Returns:
        A ``torch.float`` tensor of shape ``(content.shape[0], target_len)``
        on the same device as ``content``.
    """
    n_src = content.shape[-1]
    expanded = torch.zeros(
        [content.shape[0], target_len], dtype=torch.float, device=content.device
    )
    # boundaries[p] is the first output position owned by source column p.
    boundaries = torch.arange(n_src + 1) * target_len / n_src

    src_idx = 0
    for out_idx in range(target_len):
        # Step to the next source column once its boundary has been reached.
        if not out_idx < boundaries[src_idx + 1]:
            src_idx += 1
        expanded[:, out_idx] = content[:, src_idx]

    return expanded
|
|
|
|
|
|
|
|
|
|
|
def repeat_expand_2d_other(content, target_len, mode='nearest'):
    """Resize a 2-D tensor along its last axis with ``F.interpolate``.

    Args:
        content: 2-D tensor; a leading batch axis is added for the call.
        target_len: Desired size of the last axis.
        mode: Interpolation mode passed straight to ``F.interpolate``.

    Returns:
        The interpolated tensor with the temporary batch axis removed.
    """
    # F.interpolate needs a leading batch axis, so add one and strip it after.
    batched = content[None, :, :]
    resized = F.interpolate(batched, size=target_len, mode=mode)
    return resized[0]
|
|
|
|
|
def repeat_expand(
    content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
):
    """Resize ``content`` along its last axis to ``target_len`` elements.

    This is a wrapper of ``torch.nn.functional.interpolate`` that accepts
    1-D, 2-D or 3-D input. Missing leading axes are added for the call and
    removed again, so the result has the same number of dimensions as the
    input.

    Args:
        content (torch.Tensor | np.ndarray): 1-D, 2-D or 3-D data.
        target_len (int): target length of the last axis.
        mode (str, optional): interpolation mode. Defaults to "nearest".

    Returns:
        torch.Tensor | np.ndarray: resized data, with the same container
        type and the same number of dimensions as ``content``.

    Raises:
        AssertionError: if ``content`` is not 1-D, 2-D or 3-D.
    """
    ndim = content.ndim

    # interpolate() expects three axes here; add the missing leading ones.
    if ndim == 1:
        content = content[None, None]
    elif ndim == 2:
        content = content[None]

    assert content.ndim == 3, "content must be 1-D, 2-D or 3-D"

    is_np = isinstance(content, np.ndarray)
    if is_np:
        content = torch.from_numpy(content)

    results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)

    if is_np:
        results = results.numpy()

    # Strip the axes that were added above.
    if ndim == 1:
        return results[0, 0]
    if ndim == 2:
        return results[0]
    # 3-D input: previously fell off the end of the function and returned None.
    return results
|
|
|
|
|
class DotDict(dict):
    """Dictionary whose items can also be read, set and deleted as attributes.

    Reading a missing key through attribute access yields ``None`` (the
    result of ``dict.get``) instead of raising. A value that is exactly a
    plain ``dict`` is wrapped in a new ``DotDict`` on every read, so nested
    keys can be reached with chained dots.
    """

    def __getattr__(self, *args):
        # Only reached when normal attribute lookup fails; defer to dict.get.
        value = dict.get(self, *args)
        if type(value) is dict:
            return DotDict(value)
        return value

    # Attribute assignment and deletion act directly on the dictionary items.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
|
|
|
|
|
def load_config(config_path):
    """Load a YAML configuration file into a ``DotDict``.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        DotDict: the parsed top-level mapping, readable via attribute access.

    Raises:
        ValueError: if the file cannot be opened, cannot be parsed as YAML,
            or its content cannot be turned into a ``DotDict``. The original
            error is attached as ``__cause__``.
    """
    try:
        with open(config_path, "r") as config:
            args = yaml.safe_load(config)
        return DotDict(args)
    except Exception as err:
        # Callers still get ValueError for every failure, but now with a
        # message and the underlying cause. KeyboardInterrupt and SystemExit
        # are no longer swallowed as they were by the bare ``except:``.
        raise ValueError(
            f"Failed to load config from {config_path!r}: {err}"
        ) from err
|
|
|
|
|
|
|
|
def get_last_checkpoint(work_dir, steps=None):
    """Load the first checkpoint listed by ``get_all_ckpts``.

    Args:
        work_dir: Directory searched for checkpoint files.
        steps: Optional step value forwarded to ``get_all_ckpts``.

    Returns:
        tuple: ``(checkpoint, path)``, where ``checkpoint`` is the object
        returned by ``torch.load`` on CPU, or ``(None, None)`` when no
        checkpoint file is found.
    """
    candidates = get_all_ckpts(work_dir, steps)
    if not candidates:
        return None, None

    # The first entry of the listing is the one that gets loaded.
    newest_path = candidates[0]
    return torch.load(newest_path, map_location='cpu'), newest_path
|
|
|
|
|
|
|
|
def get_all_ckpts(work_dir, steps=None):
    """List checkpoint files in ``work_dir``, highest step number first.

    Files are expected to be named ``model_ckpt_steps_<N>.ckpt``.

    Args:
        work_dir: Directory to search.
        steps: If given, only the checkpoint with exactly this step value is
            looked up; otherwise every step is matched.

    Returns:
        list[str]: matching paths sorted by descending step number. Files
        that match the glob but have a non-numeric step (for example
        ``model_ckpt_steps_best.ckpt``) are skipped.
    """
    if steps is None:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
    else:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'

    # Raw string: the previous non-raw literal relied on invalid escape
    # sequences (\_, \d, \.), which newer Python versions warn about.
    step_regex = re.compile(r'.*steps_(\d+)\.ckpt')

    keyed_paths = []
    for path in glob.glob(ckpt_path_pattern):
        match = step_regex.search(path)
        if match is not None:
            keyed_paths.append((int(match.group(1)), path))

    # Negated key (not reverse=True) keeps the original tie order.
    keyed_paths.sort(key=lambda item: -item[0])
    return [path for _, path in keyed_paths]
|
|
|
|
|
|
|
|
def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True):
    """Load weights for ``model_name`` from a checkpoint into ``cur_model``.

    Two checkpoint layouts are understood for ``checkpoint["state_dict"]``:

    * flat: keys look like ``"<model_name>.<param>"``; the prefix is stripped
      and keys that do not start with it are ignored;
    * nested: no key contains a dot, and ``state_dict[<name>]`` is itself a
      parameter dict. A dotted ``model_name`` such as ``"a.b"`` selects the
      entries prefixed with ``"b."`` inside ``state_dict["a"]``.

    Args:
        cur_model: Module that receives the weights through
            ``load_state_dict``.
        ckpt_base_dir: A checkpoint file, or a directory in which the
            checkpoint is located with ``get_last_checkpoint``.
        model_name: Name (or dotted path) of the sub-model in the checkpoint.
        force: If True, a missing checkpoint raises; otherwise the message
            is only printed.
        strict: Passed to ``load_state_dict``. When False, parameters whose
            shape differs from the model's are dropped before loading.

    Raises:
        AssertionError: if no checkpoint is found and ``force`` is True.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    if os.path.isfile(ckpt_base_dir):
        base_dir = os.path.dirname(ckpt_base_dir)
        ckpt_path = ckpt_base_dir
        checkpoint = torch.load(ckpt_base_dir, map_location='cpu')
    else:
        base_dir = ckpt_base_dir
        checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir)

    if checkpoint is None:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            # Raised explicitly: ``assert False`` is stripped under
            # ``python -O``, which silently disabled ``force``.
            raise AssertionError(e_msg)
        print(e_msg)
        return

    state_dict = checkpoint["state_dict"]
    if any('.' in k for k in state_dict):
        # Flat layout: keep only this model's entries and strip the prefix.
        prefix = f'{model_name}.'
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()
                      if k.startswith(prefix)}
    elif '.' not in model_name:
        # Nested layout with a plain model name.
        state_dict = state_dict[model_name]
    else:
        # Nested layout with a dotted model name ("base.rest").
        base_model_name = model_name.split('.')[0]
        rest_model_name = model_name[len(base_model_name) + 1:]
        rest_prefix = f'{rest_model_name}.'
        state_dict = {
            k[len(rest_prefix):]: v for k, v in state_dict[base_model_name].items()
            if k.startswith(rest_prefix)}

    if not strict:
        # Drop shape-mismatched parameters so that load_state_dict does not
        # fail on them; they keep their current values in the model.
        cur_model_state_dict = cur_model.state_dict()
        unmatched_keys = []
        for key, param in state_dict.items():
            if key in cur_model_state_dict:
                new_param = cur_model_state_dict[key]
                if new_param.shape != param.shape:
                    unmatched_keys.append(key)
                    print("| Unmatched keys: ", key, new_param.shape, param.shape)
        for key in unmatched_keys:
            del state_dict[key]

    cur_model.load_state_dict(state_dict, strict=strict)
    print(f"| load '{model_name}' from '{ckpt_path}'.")
|
|
|