Spaces:
Sleeping
Sleeping
| import copy | |
| import math | |
| from typing import List | |
| import numpy as np | |
| import torch | |
| from audiotools import AudioSignal | |
| from audiotools import STFTParams | |
| from audiotools.core.util import ensure_tensor | |
| from audiotools.core.util import random_state | |
| from audiotools.core.util import sample_from_dist | |
| from audiotools.data.transforms import SpectralTransform | |
| from numpy.random import RandomState | |
| ################################################################################ | |
| # Phase shift transform for encouraging robust rhythm feature extraction | |
| ################################################################################ | |
class ShiftPhase(SpectralTransform):
    """
    Device-aware replacement for `audiotools.data.transforms.ShiftPhase`.

    The sampled shift is turned into a tensor and moved onto the signal's
    device before it is applied, so the transform can be processed on GPU.

    Parameters
    ----------
    shift : tuple, optional
        Distribution description handed to `sample_from_dist` when a shift
        value is drawn, by default ("uniform", -np.pi, np.pi)
    name : str, optional
        Forwarded unchanged to the `SpectralTransform` constructor,
        by default None
    prob : float, optional
        Forwarded unchanged to the `SpectralTransform` constructor,
        by default 1
    """

    def __init__(
        self,
        shift: tuple = ("uniform", -np.pi, np.pi),
        name: str = None,
        prob: float = 1,
    ):
        super().__init__(name=name, prob=prob)

        # Kept as-is; it is only consumed later by `_instantiate`.
        self.shift = shift

    def _instantiate(self, state: RandomState):
        """Draw one shift value from `self.shift` using the given random state.

        Returns a dict with the single key ``"shift"`` holding the draw.
        """
        sampled_shift = sample_from_dist(self.shift, state)
        return dict(shift=sampled_shift)

    def _transform(self, signal, shift):
        """Apply `shift` to the phase of `signal` and return the result.

        The value returned is whatever `signal.shift_phase` hands back, after
        `ensure_max_of_audio` has been called on it.
        """
        # Give the shift the same number of dimensions as the phase tensor.
        phase_rank = signal.phase.ndim
        offset = ensure_tensor(shift, ndim=phase_rank)

        # Co-locate the shift with the signal so the addition can run on the
        # signal's device (this is the GPU patch).
        offset = offset.to(signal.device)

        shifted = signal.shift_phase(offset)

        # NOTE(review): presumably rescales the audio so its peak stays within
        # range after the phase change -- confirm against audiotools.
        shifted.ensure_max_of_audio()
        return shifted