Spaces:
Runtime error
Runtime error
import collections
import collections.abc
import math
import os
import random
import subprocess
from socket import gethostname
from typing import Any, Dict, Set, Tuple, Union

import numpy as np
import torch
from loguru import logger
from torch import Tensor

# from torch._six import string_classes
from torch.autograd import Function
from torch.types import Number

from df_local.config import config
from df_local.model import ModelParams
| try: | |
| from torchaudio.functional import resample as ta_resample | |
| except ImportError: | |
| from torchaudio.compliance.kaldi import resample_waveform as ta_resample # type: ignore | |
def get_resample_params(method: str) -> Dict[str, Any]:
    """Return torchaudio resampling kwargs for a named quality preset.

    Args:
        method: One of "sinc_fast", "sinc_best", "kaiser_fast", "kaiser_best".

    Returns:
        Keyword arguments for `torchaudio.functional.resample`.

    Raises:
        AssertionError: If `method` is not a known preset name.
    """
    # NOTE: "sinc_interpolation"/"kaiser_window" are the pre-2.0 torchaudio
    # method names; newer releases renamed them — keep as-is for compatibility
    # with the pinned torchaudio version.
    params = {
        "sinc_fast": {"resampling_method": "sinc_interpolation", "lowpass_filter_width": 16},
        "sinc_best": {"resampling_method": "sinc_interpolation", "lowpass_filter_width": 64},
        "kaiser_fast": {
            "resampling_method": "kaiser_window",
            "lowpass_filter_width": 16,
            "rolloff": 0.85,
            "beta": 8.555504641634386,
        },
        "kaiser_best": {
            "resampling_method": "kaiser_window",
            # Was 16 (same as kaiser_fast, likely a copy-paste slip); the
            # torchaudio "kaiser_best" preset uses a filter width of 64.
            "lowpass_filter_width": 64,
            "rolloff": 0.9475937167399596,
            "beta": 14.769656459379492,
        },
    }
    assert method in params.keys(), f"method must be one of {list(params.keys())}"
    return params[method]
def resample(audio: Tensor, orig_sr: int, new_sr: int, method="sinc_fast"):
    """Resample `audio` from `orig_sr` to `new_sr` using a named quality preset.

    `method` selects one of the presets defined in `get_resample_params`.
    """
    kwargs = get_resample_params(method)
    return ta_resample(audio, orig_sr, new_sr, **kwargs)
def get_device():
    """Return the torch device selected by the train config.

    Reads the "DEVICE" entry from the [train] config section; when it is
    empty, falls back to "cuda:0" if CUDA is available, else CPU.
    """
    configured = config("DEVICE", default="", section="train")
    if configured != "":
        return torch.device(configured)
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")
def as_complex(x: Tensor):
    """View a real tensor whose last dimension is (re, im) as a complex tensor.

    Complex inputs pass through unchanged. A non-contiguous last dimension is
    made contiguous first, since `view_as_complex` requires stride 1 there.

    Raises:
        ValueError: If the last dimension does not have length 2.
    """
    if torch.is_complex(x):
        return x
    if x.shape[-1] != 2:
        raise ValueError(f"Last dimension need to be of length 2 (re + im), but got {x.shape}")
    return torch.view_as_complex(x if x.stride(-1) == 1 else x.contiguous())
def as_real(x: Tensor):
    """Inverse of `as_complex`: view a complex tensor as real with a trailing
    (re, im) dimension; real tensors pass through unchanged."""
    return torch.view_as_real(x) if torch.is_complex(x) else x
class angle_re_im(Function):
    """Similar to torch.angle but robustify the gradient for zero magnitude.

    Takes separate real/imaginary tensors. The backward clamps |z|^2 away from
    zero so the gradient stays finite at the origin, where atan2's true
    gradient blows up.
    """

    # torch.autograd.Function requires forward/backward to be static methods;
    # without @staticmethod, angle_re_im.apply(...) fails at runtime.
    @staticmethod
    def forward(ctx, re: Tensor, im: Tensor):
        ctx.save_for_backward(re, im)
        return torch.atan2(im, re)

    @staticmethod
    def backward(ctx, grad: Tensor) -> Tuple[Tensor, Tensor]:
        re, im = ctx.saved_tensors
        # d(atan2)/d(re) = -im / |z|^2, d(atan2)/d(im) = re / |z|^2;
        # clamp the denominator to avoid inf/nan at zero magnitude.
        grad_inv = grad / (re.square() + im.square()).clamp_min_(1e-10)
        return -im * grad_inv, re * grad_inv
class angle(Function):
    """Similar to torch.angle but robustify the gradient for zero magnitude.

    Takes a complex tensor. The backward clamps |z|^2 away from zero so the
    gradient stays finite at the origin.
    """

    # torch.autograd.Function requires forward/backward to be static methods;
    # without @staticmethod, angle.apply(...) fails at runtime.
    @staticmethod
    def forward(ctx, x: Tensor):
        ctx.save_for_backward(x)
        return torch.atan2(x.imag, x.real)

    @staticmethod
    def backward(ctx, grad: Tensor):
        (x,) = ctx.saved_tensors
        # Clamp |z|^2 to avoid inf/nan gradients at zero magnitude.
        grad_inv = grad / (x.real.square() + x.imag.square()).clamp_min_(1e-10)
        return torch.view_as_complex(torch.stack((-x.imag * grad_inv, x.real * grad_inv), dim=-1))
def check_finite_module(obj, name="Module", _raise=True) -> Set[str]:
    """Recursively collect names of non-finite (NaN/Inf) tensors in a module.

    Walks children, parameters and buffers of a `torch.nn.Module`; leaf
    tensors are checked with `torch.isfinite`.

    Args:
        obj: A `torch.nn.Module` or `Tensor` to check (other types yield no findings).
        name: Name used for reporting the current object.
        _raise: If True, raise when anything non-finite was found.

    Returns:
        Set of names of tensors containing non-finite values.

    Raises:
        ValueError: If `_raise` is True and non-finite values were found.
    """
    out: Set[str] = set()
    if isinstance(obj, torch.nn.Module):
        # Recurse with _raise=False so the error (if any) is raised once,
        # at the top level, honoring the caller's _raise flag.
        for child_name, child in obj.named_children():
            out |= check_finite_module(child, child_name, _raise=False)
        for param_name, param in obj.named_parameters():
            out |= check_finite_module(param, param_name, _raise=False)
        for buf_name, buf in obj.named_buffers():
            out |= check_finite_module(buf, buf_name, _raise=False)
    elif isinstance(obj, Tensor) and not torch.isfinite(obj).all():
        # Leaf check: previously tensors were never inspected, so non-finite
        # values always went undetected.
        out.add(name)
    if _raise and len(out) > 0:
        raise ValueError(f"{name} not finite during checkpoint writing including: {out}")
    return out
def make_np(x: Union[Tensor, np.ndarray, Number]) -> np.ndarray:
    """Transforms Tensor to numpy.
    Args:
      x: An instance of torch tensor or caffe blob name
    Returns:
      numpy.array: Numpy array
    """
    if isinstance(x, Tensor):
        return x.detach().cpu().numpy()
    if isinstance(x, np.ndarray):
        return x
    if np.isscalar(x):
        # Wrap scalars in a 1-element array.
        return np.array([x])
    raise NotImplementedError(
        "Got {}, but numpy array, scalar, or torch tensor are expected.".format(type(x))
    )
def get_norm_alpha(log: bool = True) -> float:
    """Exponential-decay normalization alpha for the current model params,
    rounded to the coarsest precision that keeps it strictly below 1.0.

    An alpha rounded up to exactly 1.0 would disable the decay, so precision
    is increased until the rounded value drops below 1.0.
    """
    p = ModelParams()
    alpha_exact = _calculate_norm_alpha(sr=p.sr, hop_size=p.hop_size, tau=p.norm_tau)
    precision = 3
    while True:
        alpha = round(alpha_exact, precision)
        if alpha < 1.0:
            break
        precision += 1
    if log:
        logger.info(f"Running with normalization window alpha = '{alpha}'")
    return alpha
| def _calculate_norm_alpha(sr: int, hop_size: int, tau: float): | |
| """Exponential decay factor alpha for a given tau (decay window size [s]).""" | |
| dt = hop_size / sr | |
| return math.exp(-dt / tau) | |
def check_manual_seed(seed: Union[int, None] = None):
    """If manual seed is not specified, choose a random one and communicate it to the user.

    Seeds numpy, python `random` and torch.

    Args:
        seed: Seed to use; when falsy, a random seed in [1, 10000] is drawn.

    Returns:
        The seed actually used.
    """
    # NOTE: `seed or ...` also replaces an explicit seed of 0 with a random
    # one — kept for backward compatibility with existing callers.
    seed = seed or random.randint(1, 10000)
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    return seed
def get_git_root():
    """Return the top-level directory of the git repo containing this file.

    Raises `subprocess.CalledProcessError` when not inside a git repository.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    cmd = ["git", "-C", module_dir, "rev-parse", "--show-toplevel"]
    return subprocess.check_output(cmd).strip().decode()
def get_commit_hash():
    """Returns the current git commit.

    Returns:
        The short commit hash as a string, or None when it cannot be
        determined (not a git repo, or git not installed).
    """
    try:
        git_dir = get_git_root()
        args = ["git", "-C", git_dir, "rev-parse", "--short", "--verify", "HEAD"]
        commit = subprocess.check_output(args).strip().decode()
    except (subprocess.CalledProcessError, FileNotFoundError):
        # CalledProcessError: probably not in a git repo.
        # FileNotFoundError: git binary not installed — previously this
        # propagated instead of falling back to None.
        commit = None
    return commit
def get_host() -> str:
    """Hostname of the machine this process is running on."""
    hostname = gethostname()
    return hostname
def get_branch_name():
    """Return the current git branch name.

    Returns:
        The branch name as a string, or None when it cannot be determined
        (not a git repo, or git not installed).
    """
    try:
        git_dir = os.path.dirname(os.path.abspath(__file__))
        args = ["git", "-C", git_dir, "rev-parse", "--abbrev-ref", "HEAD"]
        branch = subprocess.check_output(args).strip().decode()
    except (subprocess.CalledProcessError, FileNotFoundError):
        # CalledProcessError: probably not in a git repo.
        # FileNotFoundError: git binary not installed — previously this
        # propagated instead of falling back to None.
        branch = None
    return branch
| # from pytorch/ignite: | |
def apply_to_tensor(input_, func):
    """Apply a function on a tensor or mapping, or sequence of tensors.

    Recurses into modules (children), mappings and iterables; strings and
    None pass through unchanged; anything unrecognized is returned as-is.
    """
    if isinstance(input_, torch.nn.Module):
        return [apply_to_tensor(c, func) for c in input_.children()]
    elif isinstance(input_, torch.nn.Parameter):
        return func(input_.data)
    elif isinstance(input_, Tensor):
        return func(input_)
    elif isinstance(input_, str):
        # Strings are iterable; handle them before the Iterable branch.
        return input_
    elif isinstance(input_, collections.abc.Mapping):
        # collections.Mapping/Iterable aliases were removed in Python 3.10;
        # the abstract base classes live in collections.abc.
        return {k: apply_to_tensor(sample, func) for k, sample in input_.items()}
    elif isinstance(input_, collections.abc.Iterable):
        return [apply_to_tensor(sample, func) for sample in input_]
    elif input_ is None:
        return input_
    else:
        return input_
def detach_hidden(hidden: Any) -> Any:
    """Cut backpropagation graph.

    Auxillary function to cut the backpropagation graph by detaching the hidden
    vector.
    """
    return apply_to_tensor(hidden, lambda t: t.detach())