File size: 1,296 Bytes
c9f87fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import copy
import math
from typing import List

import numpy as np
import torch
from audiotools import AudioSignal
from audiotools import STFTParams
from audiotools.core.util import ensure_tensor
from audiotools.core.util import random_state
from audiotools.core.util import sample_from_dist
from audiotools.data.transforms import SpectralTransform
from numpy.random import RandomState

################################################################################
# Phase shift transform for encouraging robust rhythm feature extraction
################################################################################


class ShiftPhase(SpectralTransform):
    """
    Drop-in variant of `audiotools.data.transforms.ShiftPhase` patched so the
    transform can run on GPU.

    A shift value is drawn once per instantiation from the configured
    distribution and then applied to the signal's phase. Before it is
    applied, the shift is converted to a tensor and moved to the signal's
    device, so the phase and the shift live on the same device.

    Parameters
    ----------
    shift : tuple, optional
        Distribution specification handed to `sample_from_dist`; defaults to
        a uniform distribution over ``(-pi, pi)``.
    name : str, optional
        Forwarded unchanged to the parent transform's constructor.
    prob : float, optional
        Forwarded unchanged to the parent transform's constructor
        (presumably the probability of applying the transform -- confirm
        against the audiotools base class).
    """

    def __init__(
        self,
        shift: tuple = ("uniform", -np.pi, np.pi),
        name: str = None,
        prob: float = 1,
    ):
        super().__init__(name=name, prob=prob)
        # Only the distribution spec is stored; concrete values are sampled
        # later in `_instantiate`.
        self.shift = shift

    def _instantiate(self, state: RandomState):
        """Sample a shift from ``self.shift`` using the given random state.

        Returns a dict with the single key ``"shift"``, matching the keyword
        argument expected by `_transform`.
        """
        sampled_shift = sample_from_dist(self.shift, state)
        return {"shift": sampled_shift}

    def _transform(self, signal, shift):
        """Apply ``shift`` to the phase of ``signal`` and return the result.

        The shift is first turned into a tensor with as many dimensions as
        ``signal.phase`` and moved to ``signal.device``; this is the GPU
        patch relative to the upstream transform.
        """
        phase_rank = signal.phase.ndim
        offset = ensure_tensor(shift, ndim=phase_rank)
        offset = offset.to(signal.device)

        shifted = signal.shift_phase(offset)
        # Called for its effect on `shifted`; the return value is ignored.
        # NOTE(review): presumably this rescales audio that exceeds the
        # valid amplitude range -- confirm against audiotools.
        shifted.ensure_max_of_audio()
        return shifted