import torch
import torch.nn.functional as F

from src.simulation.component import Component

################################################################################
# Pre-emphasis filter
################################################################################


class PreEmphasis(Component):
    """

    Apply pre-emphasis filter via waveform convolution. Adapted from

    https://github.com/clovaai/voxceleb_trainer/blob/master/utils.py

    """

    def __init__(self, coef: float = 0.97, method: str = 'shift'):
        """

        Initialize filter



        :param coef: pre-emphasis coefficient

        :param method: implementation; must be one of `conv` or `shift`

        """
        super().__init__()
        self.coef = coef

        if method not in ['conv', 'shift', None]:
            raise ValueError(f'Invalid method {method}')
        self.method = method

        # flip filter (cross-correlation --> convolution)
        self.register_buffer(
            'flipped_filter',
            torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
        )

    def forward(self, x: torch.Tensor):
        """

        Apply pre-emphasis filter via waveform convolution

        """

        assert x.ndim >= 2  # require batch dimension
        n_batch, signal_length = x.shape[0], x.shape[-1]

        # require channel dimension for convolution
        x = x.reshape(n_batch, -1, signal_length)
        in_channels = x.shape[1]

        if self.method == 'conv':

            # reflect padding to match lengths of in/out
            x = F.pad(x, (1, 0), 'reflect')
            return F.conv1d(
                x,
                self.flipped_filter.repeat(in_channels, 1, 1),
                groups=in_channels
            )

        elif self.method == 'shift':

            return torch.cat(
                [
                    x[..., 0:1],
                    x[..., 1:] - self.coef*x[..., :-1]
                ], dim=-1)

        else:
            return x
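
################################################################################
# Sanity check
################################################################################

# A quick standalone sketch (not part of the module) of the relationship
# between the two methods. Both compute y[n] = x[n] - coef * x[n-1] for
# n >= 1; they differ only at the first sample, where the `conv` path's
# reflect padding yields y[0] = x[0] - coef * x[1] while `shift` keeps
# y[0] = x[0].
#
# import torch
# import torch.nn.functional as F
#
# coef = 0.97
# x = torch.randn(1, 1, 16)
#
# # 'shift' variant: y[0] = x[0], y[n] = x[n] - coef * x[n-1]
# y_shift = torch.cat([x[..., 0:1], x[..., 1:] - coef * x[..., :-1]], dim=-1)
#
# # 'conv' variant: cross-correlate with the flipped kernel [-coef, 1]
# kernel = torch.tensor([[[-coef, 1.0]]])
# y_conv = F.conv1d(F.pad(x, (1, 0), 'reflect'), kernel)
#
# # outputs agree everywhere except the first sample
# assert torch.allclose(y_shift[..., 1:], y_conv[..., 1:], atol=1e-6)

```python
import torch
import torch.nn.functional as F

coef = 0.97
x = torch.randn(1, 1, 16)

# 'shift' variant: y[0] = x[0], y[n] = x[n] - coef * x[n-1]
y_shift = torch.cat([x[..., 0:1], x[..., 1:] - coef * x[..., :-1]], dim=-1)

# 'conv' variant: cross-correlate with the flipped kernel [-coef, 1]
kernel = torch.tensor([[[-coef, 1.0]]])
y_conv = F.conv1d(F.pad(x, (1, 0), 'reflect'), kernel)

# the two agree everywhere except the first sample, where reflect
# padding makes y_conv[..., 0] = x[0] - coef * x[1] instead of x[0]
assert torch.allclose(y_shift[..., 1:], y_conv[..., 1:], atol=1e-6)
```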