import glob
import os
import re
from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F  # used by repeat_expand_2d_other
import yaml

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal


def wav_pad(wav, multiple=200):
    """Stretch a batch of sequences so their length is a multiple of ``multiple``.

    Despite the name, this does not append zeros: the last axis is resampled
    to the rounded-up length with ``repeat_expand`` (nearest-neighbour mode).

    Args:
        wav: 2-D tensor or ndarray, unpacked as (batch, seq_len).
        multiple (int): length granularity. Defaults to 200.

    Returns:
        Same type as ``wav`` with shape (batch, padded_len), where
        ``padded_len`` is ``seq_len`` rounded up to a multiple of ``multiple``.
    """
    _batch, seq_len = wav.shape
    padded_len = ((seq_len + (multiple - 1)) // multiple) * multiple
    return repeat_expand(wav, padded_len)


def repeat_expand_2d(content, target_len, mode='left'):
    """Resize the last axis of a 2-D tensor ``content`` ([h, t]) to ``target_len``.

    ``mode='left'`` dispatches to ``repeat_expand_2d_left``; any other value is
    forwarded as the ``mode`` argument of ``F.interpolate`` through
    ``repeat_expand_2d_other``.
    """
    if mode == 'left':
        return repeat_expand_2d_left(content, target_len)
    return repeat_expand_2d_other(content, target_len, mode)


def repeat_expand_3d(content, target_len, mode='left'):
    """Apply ``repeat_expand_2d`` to every item of a batched tensor ([B, h, t]).

    Returns:
        torch.Tensor of shape [B, h, target_len].
    """
    return torch.stack(
        [repeat_expand_2d(item, target_len, mode) for item in content], dim=0
    )


def repeat_expand_2d_left(content, target_len):
    """Repeat frames of ``content`` ([h, t]) so the last axis has ``target_len`` frames.

    Output frame ``i`` copies source frame ``floor(i * t / target_len)``, so the
    source frames are spread evenly over the output. When ``target_len < t``
    this picks evenly spaced frames (nearest-neighbour downsampling).

    Returns:
        float32 tensor of shape [h, target_len] on ``content.device``.
    """
    src_len = content.shape[-1]
    # Exact integer arithmetic keeps the frame boundaries correct (a float32
    # boundary table drifts on long sequences), and one gather replaces a
    # per-frame Python loop. ``max(..., 1)`` only matters when target_len == 0,
    # where the index is empty anyway.
    positions = torch.arange(target_len, device=content.device) * src_len
    index = torch.div(positions, max(target_len, 1), rounding_mode='floor')
    # The result is always float32, matching the previous zeros(dtype=float) buffer.
    return content[:, index].to(torch.float)


def repeat_expand_2d_other(content, target_len, mode='nearest'):
    """Resize the last axis of ``content`` ([h, t]) with ``F.interpolate``.

    ``content`` is lifted to a batch of one ([1, h, t]), so only modes that
    accept 3-D input apply: 'nearest', 'linear', 'area', 'nearest-exact'.
    'bilinear', 'bicubic' and 'trilinear' need 4-D/5-D input and will raise.
    """
    target = F.interpolate(content[None, :, :], size=target_len, mode=mode)
    return target[0]


def repeat_expand(
    content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
) -> Union[torch.Tensor, np.ndarray]:
    """Resample the last axis of ``content`` to ``target_len`` samples.

    Thin wrapper around ``torch.nn.functional.interpolate``. 1-D and 2-D inputs
    are lifted to 3-D for the call and squeezed back afterwards; 3-D input is
    treated as already batched.

    Args:
        content (torch.Tensor | np.ndarray): 1-D, 2-D or 3-D data.
        target_len (int): target length of the last axis.
        mode (str, optional): interpolation mode. Defaults to "nearest".

    Returns:
        torch.Tensor | np.ndarray: same type and rank as ``content``, with the
        last axis resized to ``target_len``.
    """
    ndim = content.ndim
    if ndim == 1:
        content = content[None, None]
    elif ndim == 2:
        content = content[None]
    assert content.ndim == 3, f"repeat_expand expects 1-D, 2-D or 3-D input, got {ndim}-D"

    is_np = isinstance(content, np.ndarray)
    if is_np:
        content = torch.from_numpy(content)
    results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)
    if is_np:
        results = results.numpy()

    if ndim == 1:
        return results[0, 0]
    if ndim == 2:
        return results[0]
    # 3-D input was already batched; return it unchanged in rank.
    return results


class DotDict(dict):
    """``dict`` whose keys can also be read and written as attributes.

    Reading a missing key as an attribute returns ``None`` instead of raising
    ``AttributeError``. Nested plain ``dict`` values are wrapped in a new
    ``DotDict`` on every access, so assigning through a nested level
    (``cfg.a.b = 1``) modifies a temporary copy, not ``cfg``.
    """

    def __getattr__(*args):
        val = dict.get(*args)
        return DotDict(val) if type(val) is dict else val

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


def load_config(config_path):
    """Load a YAML file (via ``yaml.safe_load``) into a ``DotDict``.

    Raises:
        ValueError: if the file cannot be opened or parsed, or the parsed
            document cannot be converted to a dict (e.g. an empty file). The
            underlying error is chained as ``__cause__``.
    """
    try:
        with open(config_path, "r") as config:
            args = yaml.safe_load(config)
        return DotDict(args)
    except Exception as exc:
        raise ValueError(f"Failed to load config from {config_path!r}") from exc


############ from controlspeech ################
def get_last_checkpoint(work_dir, steps=None):
    """Load the highest-step checkpoint returned by ``get_all_ckpts``.

    Returns:
        tuple: ``(checkpoint, path)``, where ``checkpoint`` is the object
        returned by ``torch.load(path, map_location='cpu')``. Returns
        ``(None, None)`` if no matching file exists.
    """
    ckpt_paths = get_all_ckpts(work_dir, steps)
    if not ckpt_paths:
        return None, None
    last_ckpt_path = ckpt_paths[0]
    return torch.load(last_ckpt_path, map_location='cpu'), last_ckpt_path


def get_all_ckpts(work_dir, steps=None):
    """List ``model_ckpt_steps_<N>.ckpt`` files in ``work_dir``, largest N first.

    Args:
        work_dir (str): directory to search. It is matched literally, so glob
            metacharacters such as ``[`` in the path are safe.
        steps (int, optional): if given, only ``model_ckpt_steps_<steps>.ckpt``
            is considered.

    Returns:
        list[str]: matching paths sorted by step number, descending. Files the
        glob matches whose name has no integer step (e.g.
        ``model_ckpt_steps_best.ckpt``) are skipped instead of crashing the sort.
    """
    step_part = '*' if steps is None else steps
    pattern = f'{glob.escape(work_dir)}/model_ckpt_steps_{step_part}.ckpt'

    stepped = []
    for path in glob.glob(pattern):
        match = re.fullmatch(r'model_ckpt_steps_(\d+)\.ckpt', os.path.basename(path))
        if match is not None:
            stepped.append((int(match.group(1)), path))
    stepped.sort(key=lambda item: item[0], reverse=True)
    return [path for _, path in stepped]


def _select_state_dict(state_dict, model_name):
    """Return the part of a checkpoint ``state_dict`` that belongs to ``model_name``."""
    if any('.' in k for k in state_dict):
        # Flat layout: keys look like "<model_name>.<param>"; strip the prefix.
        prefix = f'{model_name}.'
        return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
    if '.' not in model_name:
        # Nested layout: the sub-dict is stored directly under model_name.
        return state_dict[model_name]
    # Nested layout with a dotted name: "<base>.<rest>" selects the sub-dict
    # under <base>, then keeps the keys prefixed with "<rest>.".
    base_model_name, rest_model_name = model_name.split('.', 1)
    prefix = f'{rest_model_name}.'
    return {
        k[len(prefix):]: v
        for k, v in state_dict[base_model_name].items()
        if k.startswith(prefix)
    }


def _drop_shape_mismatches(state_dict, model_state_dict):
    """Delete, in place, entries of ``state_dict`` whose shape differs from the model's."""
    unmatched_keys = []
    for key, param in state_dict.items():
        if key in model_state_dict:
            new_param = model_state_dict[key]
            if new_param.shape != param.shape:
                unmatched_keys.append(key)
                print("| Unmatched keys: ", key, new_param.shape, param.shape)
    for key in unmatched_keys:
        del state_dict[key]


def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True):
    """Load weights for ``model_name`` from a checkpoint file or directory into ``cur_model``.

    Args:
        cur_model: module receiving the weights via ``load_state_dict``.
        ckpt_base_dir (str): a checkpoint file, or a directory searched with
            ``get_last_checkpoint`` (highest step wins).
        model_name (str): selects the sub-state-dict (see ``_select_state_dict``).
        force (bool): if True, a missing checkpoint raises; otherwise it is
            only printed.
        strict (bool): passed to ``load_state_dict``. When False, tensors whose
            shapes do not match the model are dropped before loading.

    Raises:
        AssertionError: if no checkpoint is found and ``force`` is True.
    """
    if os.path.isfile(ckpt_base_dir):
        base_dir = os.path.dirname(ckpt_base_dir)
        ckpt_path = ckpt_base_dir
        checkpoint = torch.load(ckpt_base_dir, map_location='cpu')
    else:
        base_dir = ckpt_base_dir
        checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir)

    if checkpoint is None:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            # Explicit raise instead of `assert False` so the check survives `python -O`.
            raise AssertionError(e_msg)
        print(e_msg)
        return

    state_dict = _select_state_dict(checkpoint["state_dict"], model_name)
    if not strict:
        _drop_shape_mismatches(state_dict, cur_model.state_dict())
    cur_model.load_state_dict(state_dict, strict=strict)
    print(f"| load '{model_name}' from '{ckpt_path}'.")