| | from functools import reduce |
| | from inspect import isfunction |
| | from math import ceil, floor, log2, pi |
| | from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from einops import rearrange |
| | from torch import Generator, Tensor |
| | from typing_extensions import TypeGuard |
| |
|
| | T = TypeVar("T") |
| |
|
| |
|
| | def exists(val: Optional[T]) -> TypeGuard[T]: |
| | return val is not None |
| |
|
| |
|
def iff(condition: bool, value: T) -> Optional[T]:
    """Pass `value` through when `condition` holds; otherwise give None."""
    if condition:
        return value
    return None
| |
|
| |
|
| | def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]: |
| | return isinstance(obj, list) or isinstance(obj, tuple) |
| |
|
| |
|
def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    """Return `val` when present; otherwise fall back to `d`.

    `d` may be a plain value or a zero-argument function used as a lazy
    default. Note only genuine Python functions (per `inspect.isfunction`)
    are invoked; other callables (classes, partials) are returned as-is.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
| |
|
| |
|
def to_list(val: Union[T, Sequence[T]]) -> List[T]:
    """Coerce `val` into a list.

    Lists pass through unchanged (same object, no copy), tuples are
    converted, and any other value is wrapped in a one-element list.
    """
    if isinstance(val, list):
        return val
    if isinstance(val, tuple):
        return list(val)
    return [val]
| |
|
| |
|
def prod(vals: Sequence[int]) -> int:
    """Return the product of all values in `vals`.

    Fixed: supplying the initializer 1 (the multiplicative identity) means
    an empty sequence now returns 1 instead of raising TypeError from
    `reduce()` called with no initial value. Non-empty inputs behave
    exactly as before.
    """
    return reduce(lambda acc, v: acc * v, vals, 1)
| |
|
| |
|
def closest_power_2(x: float) -> int:
    """Round `x` to the nearest power of two.

    Candidates are 2**floor(log2(x)) and 2**ceil(log2(x)); the one with the
    smaller absolute distance to `x` wins. On an exact tie, `min` keeps the
    first candidate, i.e. the lower power. Requires x > 0 (log2 domain).
    """
    exponent = log2(x)

    def distance(e: int) -> float:
        # Absolute gap between x and the candidate power of two.
        return abs(x - 2 ** e)

    candidates = (floor(exponent), ceil(exponent))
    best_exponent = min(candidates, key=distance)
    return 2 ** int(best_exponent)
| |
|
| |
|
| | """ |
| | Kwargs Utils |
| | """ |
| |
|
| |
|
def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    """Split `d` into two dicts: (keys starting with `prefix`, the rest).

    Insertion order within each returned dict follows `d`'s iteration order.
    """
    with_prefix: Dict = {}
    without_prefix: Dict = {}
    for key, value in d.items():
        if key.startswith(prefix):
            with_prefix[key] = value
        else:
            without_prefix[key] = value
    return with_prefix, without_prefix
| |
|
| |
|
def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    """Partition a kwargs dict by `prefix`.

    Returns (prefixed, rest). Unless `keep_prefix` is set, the prefix is
    stripped from the keys of the first dict so they can be forwarded as
    plain keyword arguments.
    """
    prefixed, rest = group_dict_by_prefix(prefix, d)
    if not keep_prefix:
        prefixed = {key[len(prefix):]: value for key, value in prefixed.items()}
    return prefixed, rest
| |
|
| |
|
def prefix_dict(prefix: str, d: Dict) -> Dict:
    """Return a copy of `d` with `prefix` prepended to every key.

    Non-string keys are stringified (f-string formatting applies str()).
    """
    return {f"{prefix}{k}": v for k, v in d.items()}
| |
|
| |
|
| | """ |
| | DSP Utils |
| | """ |
| |
|
| |
|
def resample(
    waveforms: Tensor,
    factor_in: int,
    factor_out: int,
    rolloff: float = 0.99,
    lowpass_filter_width: int = 6,
) -> Tensor:
    """Resamples a waveform using sinc interpolation, adapted from torchaudio.

    Changes the sample rate of `waveforms` by the rational factor
    factor_out / factor_in using a windowed-sinc lowpass kernel applied
    with a strided conv1d (polyphase-style: one output phase per row of
    the kernel bank, interleaved back together at the end).

    Args:
        waveforms: input of shape (batch, channels, time) — established by
            the unpack below and the "b c t" rearrange pattern.
        factor_in: denominator of the rate change (input-side factor).
        factor_out: numerator of the rate change (output-side factor).
        rolloff: fraction of the Nyquist band kept by the anti-alias
            lowpass (slightly < 1 to reduce aliasing at the band edge).
        lowpass_filter_width: half-width of the sinc kernel, in zero
            crossings; larger = sharper filter, more computation.

    Returns:
        Tensor of shape (batch, channels, int(factor_out * time / factor_in)).
    """
    b, _, length = waveforms.shape
    # Output length for a rational rate change (truncated toward zero).
    length_target = int(factor_out * length / factor_in)
    # Build kernels on the same device/dtype as the input.
    d = dict(device=waveforms.device, dtype=waveforms.dtype)

    # Cutoff is set by the lower of the two rates, shrunk by `rolloff`.
    base_factor = min(factor_in, factor_out) * rolloff
    # Kernel half-width in input samples.
    width = ceil(lowpass_filter_width * factor_in / base_factor)
    # Sample positions of the kernel taps, in units of input samples.
    idx = torch.arange(-width, width + factor_in, **d)[None, None] / factor_in
    # One fractional time offset per output phase (factor_out rows);
    # broadcasting gives t shape (factor_out, 1, taps).
    t = torch.arange(0, -factor_out, step=-1, **d)[:, None, None] / factor_out + idx
    # Clamp to the filter support and scale so t is the sinc argument (in rad).
    t = (t * base_factor).clamp(-lowpass_filter_width, lowpass_filter_width) * pi

    # cos^2 taper over the support — a raised-cosine (Hann-like) window.
    window = torch.cos(t / lowpass_filter_width / 2) ** 2
    scale = base_factor / factor_in
    # sinc(t) with the t == 0 singularity patched to 1.
    kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
    kernels *= window * scale

    # Fold channels into the batch so a single grouped-free conv1d suffices.
    waveforms = rearrange(waveforms, "b c t -> (b c) t")
    # Pad so every output sample has full kernel support at the edges.
    waveforms = F.pad(waveforms, (width, width + factor_in))
    # Each kernel row produces one output phase; stride steps by factor_in.
    resampled = F.conv1d(waveforms[:, None], kernels, stride=factor_in)
    # Interleave the k phases back into a single time axis.
    resampled = rearrange(resampled, "(b c) k l -> b c (l k)", b=b)
    # Trim the tail introduced by padding to the exact target length.
    return resampled[..., :length_target]
| |
|
| |
|
def downsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
    """Reduce the sample rate of `waveforms` by `factor` via sinc resampling.

    Extra keyword arguments are forwarded to `resample` (e.g. rolloff,
    lowpass_filter_width).
    """
    return resample(waveforms, factor_out=1, factor_in=factor, **kwargs)
| |
|
| |
|
def upsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
    """Increase the sample rate of `waveforms` by `factor` via sinc resampling.

    Extra keyword arguments are forwarded to `resample` (e.g. rolloff,
    lowpass_filter_width).
    """
    return resample(waveforms, factor_out=factor, factor_in=1, **kwargs)
| |
|
| |
|
| | """ Torch Utils """ |
| |
|
| |
|
def randn_like(tensor: Tensor, *args, generator: Optional[Generator] = None, **kwargs):
    """Like `torch.randn_like`, but accepts an optional `generator`.

    Samples standard-normal noise with `tensor`'s shape, then moves it to
    `tensor`'s device and dtype via `.to(tensor)`, so reproducible sampling
    works even though `torch.randn_like` itself has no generator argument.
    """
    noise = torch.randn(tensor.shape, *args, generator=generator, **kwargs)
    return noise.to(tensor)
| |
|